xref: /llvm-project/mlir/lib/Bindings/Python/IRCore.cpp (revision acde3f722ff3766f6f793884108d342b78623fe4)
1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <optional>
10 #include <utility>
11 
12 #include "Globals.h"
13 #include "IRModule.h"
14 #include "NanobindUtils.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
20 #include "mlir/Bindings/Python/Nanobind.h"
21 #include "mlir/Bindings/Python/NanobindAdaptors.h"
22 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallVector.h"
25 
26 namespace nb = nanobind;
27 using namespace nb::literals;
28 using namespace mlir;
29 using namespace mlir::python;
30 
31 using llvm::SmallVector;
32 using llvm::StringRef;
33 using llvm::Twine;
34 
35 //------------------------------------------------------------------------------
36 // Docstrings (trivial, non-duplicated docstrings are included inline).
37 //------------------------------------------------------------------------------
38 
// Docstrings attached to the Python bindings in this file. The text is
// user-visible, so keyword-argument names quoted here must match the bindings.

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Elide elements attributes with more than this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";

static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
188 
189 //------------------------------------------------------------------------------
190 // Utilities.
191 //------------------------------------------------------------------------------
192 
193 /// Helper for creating an @classmethod.
194 template <class Func, typename... Args>
195 nb::object classmethod(Func f, Args... args) {
196   nb::object cf = nb::cpp_function(f, args...);
197   return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
198 }
199 
200 static nb::object
201 createCustomDialectWrapper(const std::string &dialectNamespace,
202                            nb::object dialectDescriptor) {
203   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
204   if (!dialectClass) {
205     // Use the base class.
206     return nb::cast(PyDialect(std::move(dialectDescriptor)));
207   }
208 
209   // Create the custom implementation.
210   return (*dialectClass)(std::move(dialectDescriptor));
211 }
212 
213 static MlirStringRef toMlirStringRef(const std::string &s) {
214   return mlirStringRefCreate(s.data(), s.size());
215 }
216 
217 static MlirStringRef toMlirStringRef(std::string_view s) {
218   return mlirStringRefCreate(s.data(), s.size());
219 }
220 
221 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
222   return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
223 }
224 
225 /// Create a block, using the current location context if no locations are
226 /// specified.
227 static MlirBlock createBlock(const nb::sequence &pyArgTypes,
228                              const std::optional<nb::sequence> &pyArgLocs) {
229   SmallVector<MlirType> argTypes;
230   argTypes.reserve(nb::len(pyArgTypes));
231   for (const auto &pyType : pyArgTypes)
232     argTypes.push_back(nb::cast<PyType &>(pyType));
233 
234   SmallVector<MlirLocation> argLocs;
235   if (pyArgLocs) {
236     argLocs.reserve(nb::len(*pyArgLocs));
237     for (const auto &pyLoc : *pyArgLocs)
238       argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
239   } else if (!argTypes.empty()) {
240     argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
241   }
242 
243   if (argTypes.size() != argLocs.size())
244     throw nb::value_error(("Expected " + Twine(argTypes.size()) +
245                            " locations, got: " + Twine(argLocs.size()))
246                               .str()
247                               .c_str());
248   return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
249 }
250 
251 /// Wrapper for the global LLVM debugging flag.
252 struct PyGlobalDebugFlag {
253   static void set(nb::object &o, bool enable) {
254     nb::ft_lock_guard lock(mutex);
255     mlirEnableGlobalDebug(enable);
256   }
257 
258   static bool get(const nb::object &) {
259     nb::ft_lock_guard lock(mutex);
260     return mlirIsGlobalDebugEnabled();
261   }
262 
263   static void bind(nb::module_ &m) {
264     // Debug flags.
265     nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
266         .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
267                             &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
268         .def_static(
269             "set_types",
270             [](const std::string &type) {
271               nb::ft_lock_guard lock(mutex);
272               mlirSetGlobalDebugType(type.c_str());
273             },
274             "types"_a, "Sets specific debug types to be produced by LLVM")
275         .def_static("set_types", [](const std::vector<std::string> &types) {
276           std::vector<const char *> pointers;
277           pointers.reserve(types.size());
278           for (const std::string &str : types)
279             pointers.push_back(str.c_str());
280           nb::ft_lock_guard lock(mutex);
281           mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
282         });
283   }
284 
285 private:
286   static nb::ft_mutex mutex;
287 };
288 
289 nb::ft_mutex PyGlobalDebugFlag::mutex;
290 
291 struct PyAttrBuilderMap {
292   static bool dunderContains(const std::string &attributeKind) {
293     return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
294   }
295   static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
296     auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
297     if (!builder)
298       throw nb::key_error(attributeKind.c_str());
299     return *builder;
300   }
301   static void dunderSetItemNamed(const std::string &attributeKind,
302                                 nb::callable func, bool replace) {
303     PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
304                                               replace);
305   }
306 
307   static void bind(nb::module_ &m) {
308     nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
309         .def_static("contains", &PyAttrBuilderMap::dunderContains)
310         .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
311         .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
312                     "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
313                     "Register an attribute builder for building MLIR "
314                     "attributes from python values.");
315   }
316 };
317 
318 //------------------------------------------------------------------------------
319 // PyBlock
320 //------------------------------------------------------------------------------
321 
322 nb::object PyBlock::getCapsule() {
323   return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
324 }
325 
326 //------------------------------------------------------------------------------
327 // Collections.
328 //------------------------------------------------------------------------------
329 
330 namespace {
331 
332 class PyRegionIterator {
333 public:
334   PyRegionIterator(PyOperationRef operation)
335       : operation(std::move(operation)) {}
336 
337   PyRegionIterator &dunderIter() { return *this; }
338 
339   PyRegion dunderNext() {
340     operation->checkValid();
341     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
342       throw nb::stop_iteration();
343     }
344     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
345     return PyRegion(operation, region);
346   }
347 
348   static void bind(nb::module_ &m) {
349     nb::class_<PyRegionIterator>(m, "RegionIterator")
350         .def("__iter__", &PyRegionIterator::dunderIter)
351         .def("__next__", &PyRegionIterator::dunderNext);
352   }
353 
354 private:
355   PyOperationRef operation;
356   int nextIndex = 0;
357 };
358 
359 /// Regions of an op are fixed length and indexed numerically so are represented
360 /// with a sequence-like container.
361 class PyRegionList {
362 public:
363   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
364 
365   PyRegionIterator dunderIter() {
366     operation->checkValid();
367     return PyRegionIterator(operation);
368   }
369 
370   intptr_t dunderLen() {
371     operation->checkValid();
372     return mlirOperationGetNumRegions(operation->get());
373   }
374 
375   PyRegion dunderGetItem(intptr_t index) {
376     // dunderLen checks validity.
377     if (index < 0 || index >= dunderLen()) {
378       throw nb::index_error("attempt to access out of bounds region");
379     }
380     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
381     return PyRegion(operation, region);
382   }
383 
384   static void bind(nb::module_ &m) {
385     nb::class_<PyRegionList>(m, "RegionSequence")
386         .def("__len__", &PyRegionList::dunderLen)
387         .def("__iter__", &PyRegionList::dunderIter)
388         .def("__getitem__", &PyRegionList::dunderGetItem);
389   }
390 
391 private:
392   PyOperationRef operation;
393 };
394 
395 class PyBlockIterator {
396 public:
397   PyBlockIterator(PyOperationRef operation, MlirBlock next)
398       : operation(std::move(operation)), next(next) {}
399 
400   PyBlockIterator &dunderIter() { return *this; }
401 
402   PyBlock dunderNext() {
403     operation->checkValid();
404     if (mlirBlockIsNull(next)) {
405       throw nb::stop_iteration();
406     }
407 
408     PyBlock returnBlock(operation, next);
409     next = mlirBlockGetNextInRegion(next);
410     return returnBlock;
411   }
412 
413   static void bind(nb::module_ &m) {
414     nb::class_<PyBlockIterator>(m, "BlockIterator")
415         .def("__iter__", &PyBlockIterator::dunderIter)
416         .def("__next__", &PyBlockIterator::dunderNext);
417   }
418 
419 private:
420   PyOperationRef operation;
421   MlirBlock next;
422 };
423 
424 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
425 /// we present them as a more full-featured list-like container but optimize
426 /// it for forward iteration. Blocks are always owned by a region.
427 class PyBlockList {
428 public:
429   PyBlockList(PyOperationRef operation, MlirRegion region)
430       : operation(std::move(operation)), region(region) {}
431 
432   PyBlockIterator dunderIter() {
433     operation->checkValid();
434     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
435   }
436 
437   intptr_t dunderLen() {
438     operation->checkValid();
439     intptr_t count = 0;
440     MlirBlock block = mlirRegionGetFirstBlock(region);
441     while (!mlirBlockIsNull(block)) {
442       count += 1;
443       block = mlirBlockGetNextInRegion(block);
444     }
445     return count;
446   }
447 
448   PyBlock dunderGetItem(intptr_t index) {
449     operation->checkValid();
450     if (index < 0) {
451       throw nb::index_error("attempt to access out of bounds block");
452     }
453     MlirBlock block = mlirRegionGetFirstBlock(region);
454     while (!mlirBlockIsNull(block)) {
455       if (index == 0) {
456         return PyBlock(operation, block);
457       }
458       block = mlirBlockGetNextInRegion(block);
459       index -= 1;
460     }
461     throw nb::index_error("attempt to access out of bounds block");
462   }
463 
464   PyBlock appendBlock(const nb::args &pyArgTypes,
465                       const std::optional<nb::sequence> &pyArgLocs) {
466     operation->checkValid();
467     MlirBlock block =
468         createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
469     mlirRegionAppendOwnedBlock(region, block);
470     return PyBlock(operation, block);
471   }
472 
473   static void bind(nb::module_ &m) {
474     nb::class_<PyBlockList>(m, "BlockList")
475         .def("__getitem__", &PyBlockList::dunderGetItem)
476         .def("__iter__", &PyBlockList::dunderIter)
477         .def("__len__", &PyBlockList::dunderLen)
478         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
479              nb::arg("args"), nb::kw_only(),
480              nb::arg("arg_locs") = std::nullopt);
481   }
482 
483 private:
484   PyOperationRef operation;
485   MlirRegion region;
486 };
487 
488 class PyOperationIterator {
489 public:
490   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
491       : parentOperation(std::move(parentOperation)), next(next) {}
492 
493   PyOperationIterator &dunderIter() { return *this; }
494 
495   nb::object dunderNext() {
496     parentOperation->checkValid();
497     if (mlirOperationIsNull(next)) {
498       throw nb::stop_iteration();
499     }
500 
501     PyOperationRef returnOperation =
502         PyOperation::forOperation(parentOperation->getContext(), next);
503     next = mlirOperationGetNextInBlock(next);
504     return returnOperation->createOpView();
505   }
506 
507   static void bind(nb::module_ &m) {
508     nb::class_<PyOperationIterator>(m, "OperationIterator")
509         .def("__iter__", &PyOperationIterator::dunderIter)
510         .def("__next__", &PyOperationIterator::dunderNext);
511   }
512 
513 private:
514   PyOperationRef parentOperation;
515   MlirOperation next;
516 };
517 
518 /// Operations are exposed by the C-API as a forward-only linked list. In
519 /// Python, we present them as a more full-featured list-like container but
520 /// optimize it for forward iteration. Iterable operations are always owned
521 /// by a block.
522 class PyOperationList {
523 public:
524   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
525       : parentOperation(std::move(parentOperation)), block(block) {}
526 
527   PyOperationIterator dunderIter() {
528     parentOperation->checkValid();
529     return PyOperationIterator(parentOperation,
530                                mlirBlockGetFirstOperation(block));
531   }
532 
533   intptr_t dunderLen() {
534     parentOperation->checkValid();
535     intptr_t count = 0;
536     MlirOperation childOp = mlirBlockGetFirstOperation(block);
537     while (!mlirOperationIsNull(childOp)) {
538       count += 1;
539       childOp = mlirOperationGetNextInBlock(childOp);
540     }
541     return count;
542   }
543 
544   nb::object dunderGetItem(intptr_t index) {
545     parentOperation->checkValid();
546     if (index < 0) {
547       throw nb::index_error("attempt to access out of bounds operation");
548     }
549     MlirOperation childOp = mlirBlockGetFirstOperation(block);
550     while (!mlirOperationIsNull(childOp)) {
551       if (index == 0) {
552         return PyOperation::forOperation(parentOperation->getContext(), childOp)
553             ->createOpView();
554       }
555       childOp = mlirOperationGetNextInBlock(childOp);
556       index -= 1;
557     }
558     throw nb::index_error("attempt to access out of bounds operation");
559   }
560 
561   static void bind(nb::module_ &m) {
562     nb::class_<PyOperationList>(m, "OperationList")
563         .def("__getitem__", &PyOperationList::dunderGetItem)
564         .def("__iter__", &PyOperationList::dunderIter)
565         .def("__len__", &PyOperationList::dunderLen);
566   }
567 
568 private:
569   PyOperationRef parentOperation;
570   MlirBlock block;
571 };
572 
573 class PyOpOperand {
574 public:
575   PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
576 
577   nb::object getOwner() {
578     MlirOperation owner = mlirOpOperandGetOwner(opOperand);
579     PyMlirContextRef context =
580         PyMlirContext::forContext(mlirOperationGetContext(owner));
581     return PyOperation::forOperation(context, owner)->createOpView();
582   }
583 
584   size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
585 
586   static void bind(nb::module_ &m) {
587     nb::class_<PyOpOperand>(m, "OpOperand")
588         .def_prop_ro("owner", &PyOpOperand::getOwner)
589         .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
590   }
591 
592 private:
593   MlirOpOperand opOperand;
594 };
595 
596 class PyOpOperandIterator {
597 public:
598   PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
599 
600   PyOpOperandIterator &dunderIter() { return *this; }
601 
602   PyOpOperand dunderNext() {
603     if (mlirOpOperandIsNull(opOperand))
604       throw nb::stop_iteration();
605 
606     PyOpOperand returnOpOperand(opOperand);
607     opOperand = mlirOpOperandGetNextUse(opOperand);
608     return returnOpOperand;
609   }
610 
611   static void bind(nb::module_ &m) {
612     nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
613         .def("__iter__", &PyOpOperandIterator::dunderIter)
614         .def("__next__", &PyOpOperandIterator::dunderNext);
615   }
616 
617 private:
618   MlirOpOperand opOperand;
619 };
620 
621 } // namespace
622 
623 //------------------------------------------------------------------------------
624 // PyMlirContext
625 //------------------------------------------------------------------------------
626 
627 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
628   nb::gil_scoped_acquire acquire;
629   nb::ft_lock_guard lock(live_contexts_mutex);
630   auto &liveContexts = getLiveContexts();
631   liveContexts[context.ptr] = this;
632 }
633 
634 PyMlirContext::~PyMlirContext() {
635   // Note that the only public way to construct an instance is via the
636   // forContext method, which always puts the associated handle into
637   // liveContexts.
638   nb::gil_scoped_acquire acquire;
639   {
640     nb::ft_lock_guard lock(live_contexts_mutex);
641     getLiveContexts().erase(context.ptr);
642   }
643   mlirContextDestroy(context);
644 }
645 
646 nb::object PyMlirContext::getCapsule() {
647   return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
648 }
649 
650 nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
651   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
652   if (mlirContextIsNull(rawContext))
653     throw nb::python_error();
654   return forContext(rawContext).releaseObject();
655 }
656 
657 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
658   nb::gil_scoped_acquire acquire;
659   nb::ft_lock_guard lock(live_contexts_mutex);
660   auto &liveContexts = getLiveContexts();
661   auto it = liveContexts.find(context.ptr);
662   if (it == liveContexts.end()) {
663     // Create.
664     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
665     nb::object pyRef = nb::cast(unownedContextWrapper);
666     assert(pyRef && "cast to nb::object failed");
667     liveContexts[context.ptr] = unownedContextWrapper;
668     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
669   }
670   // Use existing.
671   nb::object pyRef = nb::cast(it->second);
672   return PyMlirContextRef(it->second, std::move(pyRef));
673 }
674 
675 nb::ft_mutex PyMlirContext::live_contexts_mutex;
676 
677 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
678   static LiveContextMap liveContexts;
679   return liveContexts;
680 }
681 
682 size_t PyMlirContext::getLiveCount() {
683   nb::ft_lock_guard lock(live_contexts_mutex);
684   return getLiveContexts().size();
685 }
686 
687 size_t PyMlirContext::getLiveOperationCount() {
688   nb::ft_lock_guard lock(liveOperationsMutex);
689   return liveOperations.size();
690 }
691 
692 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
693   std::vector<PyOperation *> liveObjects;
694   nb::ft_lock_guard lock(liveOperationsMutex);
695   for (auto &entry : liveOperations)
696     liveObjects.push_back(entry.second.second);
697   return liveObjects;
698 }
699 
700 size_t PyMlirContext::clearLiveOperations() {
701 
702   LiveOperationMap operations;
703   {
704     nb::ft_lock_guard lock(liveOperationsMutex);
705     std::swap(operations, liveOperations);
706   }
707   for (auto &op : operations)
708     op.second.second->setInvalid();
709   size_t numInvalidated = operations.size();
710   return numInvalidated;
711 }
712 
713 void PyMlirContext::clearOperation(MlirOperation op) {
714   PyOperation *py_op;
715   {
716     nb::ft_lock_guard lock(liveOperationsMutex);
717     auto it = liveOperations.find(op.ptr);
718     if (it == liveOperations.end()) {
719       return;
720     }
721     py_op = it->second.second;
722     liveOperations.erase(it);
723   }
724   py_op->setInvalid();
725 }
726 
727 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
728   typedef struct {
729     PyOperation &rootOp;
730     bool rootSeen;
731   } callBackData;
732   callBackData data{op.getOperation(), false};
733   // Mark all ops below the op that the passmanager will be rooted
734   // at (but not op itself - note the preorder) as invalid.
735   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
736                                                       void *userData) {
737     callBackData *data = static_cast<callBackData *>(userData);
738     if (LLVM_LIKELY(data->rootSeen))
739       data->rootOp.getOperation().getContext()->clearOperation(op);
740     else
741       data->rootSeen = true;
742     return MlirWalkResult::MlirWalkResultAdvance;
743   };
744   mlirOperationWalk(op.getOperation(), invalidatingCallback,
745                     static_cast<void *>(&data), MlirWalkPreOrder);
746 }
747 void PyMlirContext::clearOperationsInside(MlirOperation op) {
748   PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
749   clearOperationsInside(opRef->getOperation());
750 }
751 
752 void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
753   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754                                                       void *userData) {
755     PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
756     contextRef->clearOperation(op);
757     return MlirWalkResult::MlirWalkResultAdvance;
758   };
759   mlirOperationWalk(op.getOperation(), invalidatingCallback,
760                     &op.getOperation().getContext(), MlirWalkPreOrder);
761 }
762 
763 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
764 
765 nb::object PyMlirContext::contextEnter(nb::object context) {
766   return PyThreadContextEntry::pushContext(context);
767 }
768 
769 void PyMlirContext::contextExit(const nb::object &excType,
770                                 const nb::object &excVal,
771                                 const nb::object &excTb) {
772   PyThreadContextEntry::popContext(*this);
773 }
774 
775 nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
776   // Note that ownership is transferred to the delete callback below by way of
777   // an explicit inc_ref (borrow).
778   PyDiagnosticHandler *pyHandler =
779       new PyDiagnosticHandler(get(), std::move(callback));
780   nb::object pyHandlerObject =
781       nb::cast(pyHandler, nb::rv_policy::take_ownership);
782   pyHandlerObject.inc_ref();
783 
784   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
785   // guaranteed to be known to pybind.
786   auto handlerCallback =
787       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
788     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
789     nb::object pyDiagnosticObject =
790         nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
791 
792     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
793     bool result = false;
794     {
795       // Since this can be called from arbitrary C++ contexts, always get the
796       // gil.
797       nb::gil_scoped_acquire gil;
798       try {
799         result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
800       } catch (std::exception &e) {
801         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
802                 e.what());
803         pyHandler->hadError = true;
804       }
805     }
806 
807     pyDiagnostic->invalidate();
808     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
809   };
810   auto deleteCallback = +[](void *userData) {
811     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
812     assert(pyHandler->registeredID && "handler is not registered");
813     pyHandler->registeredID.reset();
814 
815     // Decrement reference, balancing the inc_ref() above.
816     nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
817     pyHandlerObject.dec_ref();
818   };
819 
820   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
821       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
822   return pyHandlerObject;
823 }
824 
825 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
826                                                        void *userData) {
827   auto *self = static_cast<ErrorCapture *>(userData);
828   // Check if the context requested we emit errors instead of capturing them.
829   if (self->ctx->emitErrorDiagnostics)
830     return mlirLogicalResultFailure();
831 
832   if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
833     return mlirLogicalResultFailure();
834 
835   self->errors.emplace_back(PyDiagnostic(diag).getInfo());
836   return mlirLogicalResultSuccess();
837 }
838 
839 PyMlirContext &DefaultingPyMlirContext::resolve() {
840   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
841   if (!context) {
842     throw std::runtime_error(
843         "An MLIR function requires a Context but none was provided in the call "
844         "or from the surrounding environment. Either pass to the function with "
845         "a 'context=' argument or establish a default using 'with Context():'");
846   }
847   return *context;
848 }
849 
850 //------------------------------------------------------------------------------
851 // PyThreadContextEntry management
852 //------------------------------------------------------------------------------
853 
854 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
855   static thread_local std::vector<PyThreadContextEntry> stack;
856   return stack;
857 }
858 
859 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
860   auto &stack = getStack();
861   if (stack.empty())
862     return nullptr;
863   return &stack.back();
864 }
865 
866 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
867                                 nb::object insertionPoint,
868                                 nb::object location) {
869   auto &stack = getStack();
870   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
871                      std::move(location));
872   // If the new stack has more than one entry and the context of the new top
873   // entry matches the previous, copy the insertionPoint and location from the
874   // previous entry if missing from the new top entry.
875   if (stack.size() > 1) {
876     auto &prev = *(stack.rbegin() + 1);
877     auto &current = stack.back();
878     if (current.context.is(prev.context)) {
879       // Default non-context objects from the previous entry.
880       if (!current.insertionPoint)
881         current.insertionPoint = prev.insertionPoint;
882       if (!current.location)
883         current.location = prev.location;
884     }
885   }
886 }
887 
888 PyMlirContext *PyThreadContextEntry::getContext() {
889   if (!context)
890     return nullptr;
891   return nb::cast<PyMlirContext *>(context);
892 }
893 
894 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
895   if (!insertionPoint)
896     return nullptr;
897   return nb::cast<PyInsertionPoint *>(insertionPoint);
898 }
899 
900 PyLocation *PyThreadContextEntry::getLocation() {
901   if (!location)
902     return nullptr;
903   return nb::cast<PyLocation *>(location);
904 }
905 
906 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
907   auto *tos = getTopOfStack();
908   return tos ? tos->getContext() : nullptr;
909 }
910 
911 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
912   auto *tos = getTopOfStack();
913   return tos ? tos->getInsertionPoint() : nullptr;
914 }
915 
916 PyLocation *PyThreadContextEntry::getDefaultLocation() {
917   auto *tos = getTopOfStack();
918   return tos ? tos->getLocation() : nullptr;
919 }
920 
921 nb::object PyThreadContextEntry::pushContext(nb::object context) {
922   push(FrameKind::Context, /*context=*/context,
923        /*insertionPoint=*/nb::object(),
924        /*location=*/nb::object());
925   return context;
926 }
927 
928 void PyThreadContextEntry::popContext(PyMlirContext &context) {
929   auto &stack = getStack();
930   if (stack.empty())
931     throw std::runtime_error("Unbalanced Context enter/exit");
932   auto &tos = stack.back();
933   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
934     throw std::runtime_error("Unbalanced Context enter/exit");
935   stack.pop_back();
936 }
937 
938 nb::object
939 PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
940   PyInsertionPoint &insertionPoint =
941       nb::cast<PyInsertionPoint &>(insertionPointObj);
942   nb::object contextObj =
943       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
944   push(FrameKind::InsertionPoint,
945        /*context=*/contextObj,
946        /*insertionPoint=*/insertionPointObj,
947        /*location=*/nb::object());
948   return insertionPointObj;
949 }
950 
951 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
952   auto &stack = getStack();
953   if (stack.empty())
954     throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
955   auto &tos = stack.back();
956   if (tos.frameKind != FrameKind::InsertionPoint &&
957       tos.getInsertionPoint() != &insertionPoint)
958     throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
959   stack.pop_back();
960 }
961 
962 nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
963   PyLocation &location = nb::cast<PyLocation &>(locationObj);
964   nb::object contextObj = location.getContext().getObject();
965   push(FrameKind::Location, /*context=*/contextObj,
966        /*insertionPoint=*/nb::object(),
967        /*location=*/locationObj);
968   return locationObj;
969 }
970 
971 void PyThreadContextEntry::popLocation(PyLocation &location) {
972   auto &stack = getStack();
973   if (stack.empty())
974     throw std::runtime_error("Unbalanced Location enter/exit");
975   auto &tos = stack.back();
976   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
977     throw std::runtime_error("Unbalanced Location enter/exit");
978   stack.pop_back();
979 }
980 
981 //------------------------------------------------------------------------------
982 // PyDiagnostic*
983 //------------------------------------------------------------------------------
984 
985 void PyDiagnostic::invalidate() {
986   valid = false;
987   if (materializedNotes) {
988     for (nb::handle noteObject : *materializedNotes) {
989       PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
990       note->invalidate();
991     }
992   }
993 }
994 
995 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
996                                          nb::object callback)
997     : context(context), callback(std::move(callback)) {}
998 
999 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
1000 
1001 void PyDiagnosticHandler::detach() {
1002   if (!registeredID)
1003     return;
1004   MlirDiagnosticHandlerID localID = *registeredID;
1005   mlirContextDetachDiagnosticHandler(context, localID);
1006   assert(!registeredID && "should have unregistered");
1007   // Not strictly necessary but keeps stale pointers from being around to cause
1008   // issues.
1009   context = {nullptr};
1010 }
1011 
1012 void PyDiagnostic::checkValid() {
1013   if (!valid) {
1014     throw std::invalid_argument(
1015         "Diagnostic is invalid (used outside of callback)");
1016   }
1017 }
1018 
1019 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
1020   checkValid();
1021   return mlirDiagnosticGetSeverity(diagnostic);
1022 }
1023 
1024 PyLocation PyDiagnostic::getLocation() {
1025   checkValid();
1026   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1027   MlirContext context = mlirLocationGetContext(loc);
1028   return PyLocation(PyMlirContext::forContext(context), loc);
1029 }
1030 
1031 nb::str PyDiagnostic::getMessage() {
1032   checkValid();
1033   nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
1034   PyFileAccumulator accum(fileObject, /*binary=*/false);
1035   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
1036   return nb::cast<nb::str>(fileObject.attr("getvalue")());
1037 }
1038 
1039 nb::tuple PyDiagnostic::getNotes() {
1040   checkValid();
1041   if (materializedNotes)
1042     return *materializedNotes;
1043   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1044   nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
1045   for (intptr_t i = 0; i < numNotes; ++i) {
1046     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1047     nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
1048     PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
1049   }
1050   materializedNotes = std::move(notes);
1051 
1052   return *materializedNotes;
1053 }
1054 
1055 PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
1056   std::vector<DiagnosticInfo> notes;
1057   for (nb::handle n : getNotes())
1058     notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1059   return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1060           std::move(notes)};
1061 }
1062 
1063 //------------------------------------------------------------------------------
1064 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1065 //------------------------------------------------------------------------------
1066 
1067 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1068                                          bool attrError) {
1069   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1070                                                     {key.data(), key.size()});
1071   if (mlirDialectIsNull(dialect)) {
1072     std::string msg = (Twine("Dialect '") + key + "' not found").str();
1073     if (attrError)
1074       throw nb::attribute_error(msg.c_str());
1075     throw nb::index_error(msg.c_str());
1076   }
1077   return dialect;
1078 }
1079 
1080 nb::object PyDialectRegistry::getCapsule() {
1081   return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
1082 }
1083 
1084 PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) {
1085   MlirDialectRegistry rawRegistry =
1086       mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1087   if (mlirDialectRegistryIsNull(rawRegistry))
1088     throw nb::python_error();
1089   return PyDialectRegistry(rawRegistry);
1090 }
1091 
1092 //------------------------------------------------------------------------------
1093 // PyLocation
1094 //------------------------------------------------------------------------------
1095 
1096 nb::object PyLocation::getCapsule() {
1097   return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1098 }
1099 
1100 PyLocation PyLocation::createFromCapsule(nb::object capsule) {
1101   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1102   if (mlirLocationIsNull(rawLoc))
1103     throw nb::python_error();
1104   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1105                     rawLoc);
1106 }
1107 
1108 nb::object PyLocation::contextEnter(nb::object locationObj) {
1109   return PyThreadContextEntry::pushLocation(locationObj);
1110 }
1111 
1112 void PyLocation::contextExit(const nb::object &excType,
1113                              const nb::object &excVal,
1114                              const nb::object &excTb) {
1115   PyThreadContextEntry::popLocation(*this);
1116 }
1117 
1118 PyLocation &DefaultingPyLocation::resolve() {
1119   auto *location = PyThreadContextEntry::getDefaultLocation();
1120   if (!location) {
1121     throw std::runtime_error(
1122         "An MLIR function requires a Location but none was provided in the "
1123         "call or from the surrounding environment. Either pass to the function "
1124         "with a 'loc=' argument or establish a default using 'with loc:'");
1125   }
1126   return *location;
1127 }
1128 
1129 //------------------------------------------------------------------------------
1130 // PyModule
1131 //------------------------------------------------------------------------------
1132 
1133 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1134     : BaseContextObject(std::move(contextRef)), module(module) {}
1135 
1136 PyModule::~PyModule() {
1137   nb::gil_scoped_acquire acquire;
1138   auto &liveModules = getContext()->liveModules;
1139   assert(liveModules.count(module.ptr) == 1 &&
1140          "destroying module not in live map");
1141   liveModules.erase(module.ptr);
1142   mlirModuleDestroy(module);
1143 }
1144 
1145 PyModuleRef PyModule::forModule(MlirModule module) {
1146   MlirContext context = mlirModuleGetContext(module);
1147   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1148 
1149   nb::gil_scoped_acquire acquire;
1150   auto &liveModules = contextRef->liveModules;
1151   auto it = liveModules.find(module.ptr);
1152   if (it == liveModules.end()) {
1153     // Create.
1154     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1155     // Note that the default return value policy on cast is automatic_reference,
1156     // which does not take ownership (delete will not be called).
1157     // Just be explicit.
1158     nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1159     unownedModule->handle = pyRef;
1160     liveModules[module.ptr] =
1161         std::make_pair(unownedModule->handle, unownedModule);
1162     return PyModuleRef(unownedModule, std::move(pyRef));
1163   }
1164   // Use existing.
1165   PyModule *existing = it->second.second;
1166   nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1167   return PyModuleRef(existing, std::move(pyRef));
1168 }
1169 
1170 nb::object PyModule::createFromCapsule(nb::object capsule) {
1171   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1172   if (mlirModuleIsNull(rawModule))
1173     throw nb::python_error();
1174   return forModule(rawModule).releaseObject();
1175 }
1176 
1177 nb::object PyModule::getCapsule() {
1178   return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1179 }
1180 
1181 //------------------------------------------------------------------------------
1182 // PyOperation
1183 //------------------------------------------------------------------------------
1184 
1185 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1186     : BaseContextObject(std::move(contextRef)), operation(operation) {}
1187 
1188 PyOperation::~PyOperation() {
1189   // If the operation has already been invalidated there is nothing to do.
1190   if (!valid)
1191     return;
1192 
1193   // Otherwise, invalidate the operation and remove it from live map when it is
1194   // attached.
1195   if (isAttached()) {
1196     getContext()->clearOperation(*this);
1197   } else {
1198     // And destroy it when it is detached, i.e. owned by Python, in which case
1199     // all nested operations must be invalidated at removed from the live map as
1200     // well.
1201     erase();
1202   }
1203 }
1204 
1205 namespace {
1206 
1207 // Constructs a new object of type T in-place on the Python heap, returning a
1208 // PyObjectRef to it, loosely analogous to std::make_shared<T>().
1209 template <typename T, class... Args>
1210 PyObjectRef<T> makeObjectRef(Args &&...args) {
1211   nb::handle type = nb::type<T>();
1212   nb::object instance = nb::inst_alloc(type);
1213   T *ptr = nb::inst_ptr<T>(instance);
1214   new (ptr) T(std::forward<Args>(args)...);
1215   nb::inst_mark_ready(instance);
1216   return PyObjectRef<T>(ptr, std::move(instance));
1217 }
1218 
1219 } // namespace
1220 
1221 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1222                                            MlirOperation operation,
1223                                            nb::object parentKeepAlive) {
1224   // Create.
1225   PyOperationRef unownedOperation =
1226       makeObjectRef<PyOperation>(std::move(contextRef), operation);
1227   unownedOperation->handle = unownedOperation.getObject();
1228   if (parentKeepAlive) {
1229     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1230   }
1231   return unownedOperation;
1232 }
1233 
1234 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1235                                          MlirOperation operation,
1236                                          nb::object parentKeepAlive) {
1237   nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1238   auto &liveOperations = contextRef->liveOperations;
1239   auto it = liveOperations.find(operation.ptr);
1240   if (it == liveOperations.end()) {
1241     // Create.
1242     PyOperationRef result = createInstance(std::move(contextRef), operation,
1243                                            std::move(parentKeepAlive));
1244     liveOperations[operation.ptr] =
1245         std::make_pair(result.getObject(), result.get());
1246     return result;
1247   }
1248   // Use existing.
1249   PyOperation *existing = it->second.second;
1250   nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1251   return PyOperationRef(existing, std::move(pyRef));
1252 }
1253 
1254 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1255                                            MlirOperation operation,
1256                                            nb::object parentKeepAlive) {
1257   nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1258   auto &liveOperations = contextRef->liveOperations;
1259   assert(liveOperations.count(operation.ptr) == 0 &&
1260          "cannot create detached operation that already exists");
1261   (void)liveOperations;
1262   PyOperationRef created = createInstance(std::move(contextRef), operation,
1263                                           std::move(parentKeepAlive));
1264   liveOperations[operation.ptr] =
1265       std::make_pair(created.getObject(), created.get());
1266   created->attached = false;
1267   return created;
1268 }
1269 
1270 PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
1271                                   const std::string &sourceStr,
1272                                   const std::string &sourceName) {
1273   PyMlirContext::ErrorCapture errors(contextRef);
1274   MlirOperation op =
1275       mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1276                                toMlirStringRef(sourceName));
1277   if (mlirOperationIsNull(op))
1278     throw MLIRError("Unable to parse operation assembly", errors.take());
1279   return PyOperation::createDetached(std::move(contextRef), op);
1280 }
1281 
1282 void PyOperation::checkValid() const {
1283   if (!valid) {
1284     throw std::runtime_error("the operation has been invalidated");
1285   }
1286 }
1287 
1288 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1289                             bool enableDebugInfo, bool prettyDebugInfo,
1290                             bool printGenericOpForm, bool useLocalScope,
1291                             bool assumeVerified, nb::object fileObject,
1292                             bool binary, bool skipRegions) {
1293   PyOperation &operation = getOperation();
1294   operation.checkValid();
1295   if (fileObject.is_none())
1296     fileObject = nb::module_::import_("sys").attr("stdout");
1297 
1298   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1299   if (largeElementsLimit)
1300     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1301   if (enableDebugInfo)
1302     mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1303                                        /*prettyForm=*/prettyDebugInfo);
1304   if (printGenericOpForm)
1305     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1306   if (useLocalScope)
1307     mlirOpPrintingFlagsUseLocalScope(flags);
1308   if (assumeVerified)
1309     mlirOpPrintingFlagsAssumeVerified(flags);
1310   if (skipRegions)
1311     mlirOpPrintingFlagsSkipRegions(flags);
1312 
1313   PyFileAccumulator accum(fileObject, binary);
1314   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1315                               accum.getUserData());
1316   mlirOpPrintingFlagsDestroy(flags);
1317 }
1318 
1319 void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1320                             bool binary) {
1321   PyOperation &operation = getOperation();
1322   operation.checkValid();
1323   if (fileObject.is_none())
1324     fileObject = nb::module_::import_("sys").attr("stdout");
1325   PyFileAccumulator accum(fileObject, binary);
1326   mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1327                               accum.getUserData());
1328 }
1329 
1330 void PyOperationBase::writeBytecode(const nb::object &fileObject,
1331                                     std::optional<int64_t> bytecodeVersion) {
1332   PyOperation &operation = getOperation();
1333   operation.checkValid();
1334   PyFileAccumulator accum(fileObject, /*binary=*/true);
1335 
1336   if (!bytecodeVersion.has_value())
1337     return mlirOperationWriteBytecode(operation, accum.getCallback(),
1338                                       accum.getUserData());
1339 
1340   MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1341   mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1342   MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
1343       operation, config, accum.getCallback(), accum.getUserData());
1344   mlirBytecodeWriterConfigDestroy(config);
1345   if (mlirLogicalResultIsFailure(res))
1346     throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1347                            Twine(*bytecodeVersion))
1348                               .str()
1349                               .c_str());
1350 }
1351 
1352 void PyOperationBase::walk(
1353     std::function<MlirWalkResult(MlirOperation)> callback,
1354     MlirWalkOrder walkOrder) {
1355   PyOperation &operation = getOperation();
1356   operation.checkValid();
1357   struct UserData {
1358     std::function<MlirWalkResult(MlirOperation)> callback;
1359     bool gotException;
1360     std::string exceptionWhat;
1361     nb::object exceptionType;
1362   };
1363   UserData userData{callback, false, {}, {}};
1364   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1365                                               void *userData) {
1366     UserData *calleeUserData = static_cast<UserData *>(userData);
1367     try {
1368       return (calleeUserData->callback)(op);
1369     } catch (nb::python_error &e) {
1370       calleeUserData->gotException = true;
1371       calleeUserData->exceptionWhat = std::string(e.what());
1372       calleeUserData->exceptionType = nb::borrow(e.type());
1373       return MlirWalkResult::MlirWalkResultInterrupt;
1374     }
1375   };
1376   mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1377   if (userData.gotException) {
1378     std::string message("Exception raised in callback: ");
1379     message.append(userData.exceptionWhat);
1380     throw std::runtime_error(message);
1381   }
1382 }
1383 
1384 nb::object PyOperationBase::getAsm(bool binary,
1385                                    std::optional<int64_t> largeElementsLimit,
1386                                    bool enableDebugInfo, bool prettyDebugInfo,
1387                                    bool printGenericOpForm, bool useLocalScope,
1388                                    bool assumeVerified, bool skipRegions) {
1389   nb::object fileObject;
1390   if (binary) {
1391     fileObject = nb::module_::import_("io").attr("BytesIO")();
1392   } else {
1393     fileObject = nb::module_::import_("io").attr("StringIO")();
1394   }
1395   print(/*largeElementsLimit=*/largeElementsLimit,
1396         /*enableDebugInfo=*/enableDebugInfo,
1397         /*prettyDebugInfo=*/prettyDebugInfo,
1398         /*printGenericOpForm=*/printGenericOpForm,
1399         /*useLocalScope=*/useLocalScope,
1400         /*assumeVerified=*/assumeVerified,
1401         /*fileObject=*/fileObject,
1402         /*binary=*/binary,
1403         /*skipRegions=*/skipRegions);
1404 
1405   return fileObject.attr("getvalue")();
1406 }
1407 
1408 void PyOperationBase::moveAfter(PyOperationBase &other) {
1409   PyOperation &operation = getOperation();
1410   PyOperation &otherOp = other.getOperation();
1411   operation.checkValid();
1412   otherOp.checkValid();
1413   mlirOperationMoveAfter(operation, otherOp);
1414   operation.parentKeepAlive = otherOp.parentKeepAlive;
1415 }
1416 
1417 void PyOperationBase::moveBefore(PyOperationBase &other) {
1418   PyOperation &operation = getOperation();
1419   PyOperation &otherOp = other.getOperation();
1420   operation.checkValid();
1421   otherOp.checkValid();
1422   mlirOperationMoveBefore(operation, otherOp);
1423   operation.parentKeepAlive = otherOp.parentKeepAlive;
1424 }
1425 
1426 bool PyOperationBase::verify() {
1427   PyOperation &op = getOperation();
1428   PyMlirContext::ErrorCapture errors(op.getContext());
1429   if (!mlirOperationVerify(op.get()))
1430     throw MLIRError("Verification failed", errors.take());
1431   return true;
1432 }
1433 
1434 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1435   checkValid();
1436   if (!isAttached())
1437     throw nb::value_error("Detached operations have no parent");
1438   MlirOperation operation = mlirOperationGetParentOperation(get());
1439   if (mlirOperationIsNull(operation))
1440     return {};
1441   return PyOperation::forOperation(getContext(), operation);
1442 }
1443 
1444 PyBlock PyOperation::getBlock() {
1445   checkValid();
1446   std::optional<PyOperationRef> parentOperation = getParentOperation();
1447   MlirBlock block = mlirOperationGetBlock(get());
1448   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1449   assert(parentOperation && "Operation has no parent");
1450   return PyBlock{std::move(*parentOperation), block};
1451 }
1452 
1453 nb::object PyOperation::getCapsule() {
1454   checkValid();
1455   return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1456 }
1457 
1458 nb::object PyOperation::createFromCapsule(nb::object capsule) {
1459   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1460   if (mlirOperationIsNull(rawOperation))
1461     throw nb::python_error();
1462   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1463   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1464       .releaseObject();
1465 }
1466 
1467 static void maybeInsertOperation(PyOperationRef &op,
1468                                  const nb::object &maybeIp) {
1469   // InsertPoint active?
1470   if (!maybeIp.is(nb::cast(false))) {
1471     PyInsertionPoint *ip;
1472     if (maybeIp.is_none()) {
1473       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1474     } else {
1475       ip = nb::cast<PyInsertionPoint *>(maybeIp);
1476     }
1477     if (ip)
1478       ip->insert(*op.get());
1479   }
1480 }
1481 
1482 nb::object PyOperation::create(std::string_view name,
1483                                std::optional<std::vector<PyType *>> results,
1484                                llvm::ArrayRef<MlirValue> operands,
1485                                std::optional<nb::dict> attributes,
1486                                std::optional<std::vector<PyBlock *>> successors,
1487                                int regions, DefaultingPyLocation location,
1488                                const nb::object &maybeIp, bool inferType) {
1489   llvm::SmallVector<MlirType, 4> mlirResults;
1490   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1491   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
1492 
1493   // General parameter validation.
1494   if (regions < 0)
1495     throw nb::value_error("number of regions must be >= 0");
1496 
1497   // Unpack/validate results.
1498   if (results) {
1499     mlirResults.reserve(results->size());
1500     for (PyType *result : *results) {
1501       // TODO: Verify result type originate from the same context.
1502       if (!result)
1503         throw nb::value_error("result type cannot be None");
1504       mlirResults.push_back(*result);
1505     }
1506   }
1507   // Unpack/validate attributes.
1508   if (attributes) {
1509     mlirAttributes.reserve(attributes->size());
1510     for (std::pair<nb::handle, nb::handle> it : *attributes) {
1511       std::string key;
1512       try {
1513         key = nb::cast<std::string>(it.first);
1514       } catch (nb::cast_error &err) {
1515         std::string msg = "Invalid attribute key (not a string) when "
1516                           "attempting to create the operation \"" +
1517                           std::string(name) + "\" (" + err.what() + ")";
1518         throw nb::type_error(msg.c_str());
1519       }
1520       try {
1521         auto &attribute = nb::cast<PyAttribute &>(it.second);
1522         // TODO: Verify attribute originates from the same context.
1523         mlirAttributes.emplace_back(std::move(key), attribute);
1524       } catch (nb::cast_error &err) {
1525         std::string msg = "Invalid attribute value for the key \"" + key +
1526                           "\" when attempting to create the operation \"" +
1527                           std::string(name) + "\" (" + err.what() + ")";
1528         throw nb::type_error(msg.c_str());
1529       } catch (std::runtime_error &) {
1530         // This exception seems thrown when the value is "None".
1531         std::string msg =
1532             "Found an invalid (`None`?) attribute value for the key \"" + key +
1533             "\" when attempting to create the operation \"" +
1534             std::string(name) + "\"";
1535         throw std::runtime_error(msg);
1536       }
1537     }
1538   }
1539   // Unpack/validate successors.
1540   if (successors) {
1541     mlirSuccessors.reserve(successors->size());
1542     for (auto *successor : *successors) {
1543       // TODO: Verify successor originate from the same context.
1544       if (!successor)
1545         throw nb::value_error("successor block cannot be None");
1546       mlirSuccessors.push_back(successor->get());
1547     }
1548   }
1549 
1550   // Apply unpacked/validated to the operation state. Beyond this
1551   // point, exceptions cannot be thrown or else the state will leak.
1552   MlirOperationState state =
1553       mlirOperationStateGet(toMlirStringRef(name), location);
1554   if (!operands.empty())
1555     mlirOperationStateAddOperands(&state, operands.size(), operands.data());
1556   state.enableResultTypeInference = inferType;
1557   if (!mlirResults.empty())
1558     mlirOperationStateAddResults(&state, mlirResults.size(),
1559                                  mlirResults.data());
1560   if (!mlirAttributes.empty()) {
1561     // Note that the attribute names directly reference bytes in
1562     // mlirAttributes, so that vector must not be changed from here
1563     // on.
1564     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1565     mlirNamedAttributes.reserve(mlirAttributes.size());
1566     for (auto &it : mlirAttributes)
1567       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1568           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1569                             toMlirStringRef(it.first)),
1570           it.second));
1571     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1572                                     mlirNamedAttributes.data());
1573   }
1574   if (!mlirSuccessors.empty())
1575     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1576                                     mlirSuccessors.data());
1577   if (regions) {
1578     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1579     mlirRegions.resize(regions);
1580     for (int i = 0; i < regions; ++i)
1581       mlirRegions[i] = mlirRegionCreate();
1582     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1583                                       mlirRegions.data());
1584   }
1585 
1586   // Construct the operation.
1587   MlirOperation operation = mlirOperationCreate(&state);
1588   if (!operation.ptr)
1589     throw nb::value_error("Operation creation failed");
1590   PyOperationRef created =
1591       PyOperation::createDetached(location->getContext(), operation);
1592   maybeInsertOperation(created, maybeIp);
1593 
1594   return created.getObject();
1595 }
1596 
1597 nb::object PyOperation::clone(const nb::object &maybeIp) {
1598   MlirOperation clonedOperation = mlirOperationClone(operation);
1599   PyOperationRef cloned =
1600       PyOperation::createDetached(getContext(), clonedOperation);
1601   maybeInsertOperation(cloned, maybeIp);
1602 
1603   return cloned->createOpView();
1604 }
1605 
1606 nb::object PyOperation::createOpView() {
1607   checkValid();
1608   MlirIdentifier ident = mlirOperationGetName(get());
1609   MlirStringRef identStr = mlirIdentifierStr(ident);
1610   auto operationCls = PyGlobals::get().lookupOperationClass(
1611       StringRef(identStr.data, identStr.length));
1612   if (operationCls)
1613     return PyOpView::constructDerived(*operationCls, getRef().getObject());
1614   return nb::cast(PyOpView(getRef().getObject()));
1615 }
1616 
1617 void PyOperation::erase() {
1618   checkValid();
1619   getContext()->clearOperationAndInside(*this);
1620   mlirOperationDestroy(operation);
1621 }
1622 
1623 namespace {
1624 /// CRTP base class for Python MLIR values that subclass Value and should be
1625 /// castable from it. The value hierarchy is one level deep and is not supposed
1626 /// to accommodate other levels unless core MLIR changes.
1627 template <typename DerivedTy>
1628 class PyConcreteValue : public PyValue {
1629 public:
1630   // Derived classes must define statics for:
1631   //   IsAFunctionTy isaFunction
1632   //   const char *pyClassName
1633   // and redefine bindDerived.
1634   using ClassTy = nb::class_<DerivedTy, PyValue>;
1635   using IsAFunctionTy = bool (*)(MlirValue);
1636 
1637   PyConcreteValue() = default;
1638   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1639       : PyValue(operationRef, value) {}
1640   PyConcreteValue(PyValue &orig)
1641       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1642 
1643   /// Attempts to cast the original value to the derived type and throws on
1644   /// type mismatches.
1645   static MlirValue castFrom(PyValue &orig) {
1646     if (!DerivedTy::isaFunction(orig.get())) {
1647       auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1648       throw nb::value_error((Twine("Cannot cast value to ") +
1649                              DerivedTy::pyClassName + " (from " + origRepr +
1650                              ")")
1651                                 .str()
1652                                 .c_str());
1653     }
1654     return orig.get();
1655   }
1656 
1657   /// Binds the Python module objects to functions of this class.
1658   static void bind(nb::module_ &m) {
1659     auto cls = ClassTy(m, DerivedTy::pyClassName);
1660     cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1661     cls.def_static(
1662         "isinstance",
1663         [](PyValue &otherValue) -> bool {
1664           return DerivedTy::isaFunction(otherValue);
1665         },
1666         nb::arg("other_value"));
1667     cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
1668             [](DerivedTy &self) { return self.maybeDownCast(); });
1669     DerivedTy::bindDerived(cls);
1670   }
1671 
1672   /// Implemented by derived classes to add methods to the Python subclass.
1673   static void bindDerived(ClassTy &m) {}
1674 };
1675 
1676 } // namespace
1677 
1678 /// Python wrapper for MlirOpResult.
1679 class PyOpResult : public PyConcreteValue<PyOpResult> {
1680 public:
1681   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1682   static constexpr const char *pyClassName = "OpResult";
1683   using PyConcreteValue::PyConcreteValue;
1684 
1685   static void bindDerived(ClassTy &c) {
1686     c.def_prop_ro("owner", [](PyOpResult &self) {
1687       assert(
1688           mlirOperationEqual(self.getParentOperation()->get(),
1689                              mlirOpResultGetOwner(self.get())) &&
1690           "expected the owner of the value in Python to match that in the IR");
1691       return self.getParentOperation().getObject();
1692     });
1693     c.def_prop_ro("result_number", [](PyOpResult &self) {
1694       return mlirOpResultGetResultNumber(self.get());
1695     });
1696   }
1697 };
1698 
1699 /// Returns the list of types of the values held by container.
1700 template <typename Container>
1701 static std::vector<MlirType> getValueTypes(Container &container,
1702                                            PyMlirContextRef &context) {
1703   std::vector<MlirType> result;
1704   result.reserve(container.size());
1705   for (int i = 0, e = container.size(); i < e; ++i) {
1706     result.push_back(mlirValueGetType(container.getElement(i).get()));
1707   }
1708   return result;
1709 }
1710 
1711 /// A list of operation results. Internally, these are stored as consecutive
1712 /// elements, random access is cheap. The (returned) result list is associated
1713 /// with the operation whose results these are, and thus extends the lifetime of
1714 /// this operation.
1715 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1716 public:
1717   static constexpr const char *pyClassName = "OpResultList";
1718   using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
1719 
1720   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1721                  intptr_t length = -1, intptr_t step = 1)
1722       : Sliceable(startIndex,
1723                   length == -1 ? mlirOperationGetNumResults(operation->get())
1724                                : length,
1725                   step),
1726         operation(std::move(operation)) {}
1727 
1728   static void bindDerived(ClassTy &c) {
1729     c.def_prop_ro("types", [](PyOpResultList &self) {
1730       return getValueTypes(self, self.operation->getContext());
1731     });
1732     c.def_prop_ro("owner", [](PyOpResultList &self) {
1733       return self.operation->createOpView();
1734     });
1735   }
1736 
1737   PyOperationRef &getOperation() { return operation; }
1738 
1739 private:
1740   /// Give the parent CRTP class access to hook implementations below.
1741   friend class Sliceable<PyOpResultList, PyOpResult>;
1742 
1743   intptr_t getRawNumElements() {
1744     operation->checkValid();
1745     return mlirOperationGetNumResults(operation->get());
1746   }
1747 
1748   PyOpResult getRawElement(intptr_t index) {
1749     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1750     return PyOpResult(value);
1751   }
1752 
1753   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1754     return PyOpResultList(operation, startIndex, length, step);
1755   }
1756 
1757   PyOperationRef operation;
1758 };
1759 
1760 //------------------------------------------------------------------------------
1761 // PyOpView
1762 //------------------------------------------------------------------------------
1763 
1764 static void populateResultTypes(StringRef name, nb::list resultTypeList,
1765                                 const nb::object &resultSegmentSpecObj,
1766                                 std::vector<int32_t> &resultSegmentLengths,
1767                                 std::vector<PyType *> &resultTypes) {
1768   resultTypes.reserve(resultTypeList.size());
1769   if (resultSegmentSpecObj.is_none()) {
1770     // Non-variadic result unpacking.
1771     for (const auto &it : llvm::enumerate(resultTypeList)) {
1772       try {
1773         resultTypes.push_back(nb::cast<PyType *>(it.value()));
1774         if (!resultTypes.back())
1775           throw nb::cast_error();
1776       } catch (nb::cast_error &err) {
1777         throw nb::value_error((llvm::Twine("Result ") +
1778                                llvm::Twine(it.index()) + " of operation \"" +
1779                                name + "\" must be a Type (" + err.what() + ")")
1780                                   .str()
1781                                   .c_str());
1782       }
1783     }
1784   } else {
1785     // Sized result unpacking.
1786     auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1787     if (resultSegmentSpec.size() != resultTypeList.size()) {
1788       throw nb::value_error((llvm::Twine("Operation \"") + name +
1789                              "\" requires " +
1790                              llvm::Twine(resultSegmentSpec.size()) +
1791                              " result segments but was provided " +
1792                              llvm::Twine(resultTypeList.size()))
1793                                 .str()
1794                                 .c_str());
1795     }
1796     resultSegmentLengths.reserve(resultTypeList.size());
1797     for (const auto &it :
1798          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1799       int segmentSpec = std::get<1>(it.value());
1800       if (segmentSpec == 1 || segmentSpec == 0) {
1801         // Unpack unary element.
1802         try {
1803           auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1804           if (resultType) {
1805             resultTypes.push_back(resultType);
1806             resultSegmentLengths.push_back(1);
1807           } else if (segmentSpec == 0) {
1808             // Allowed to be optional.
1809             resultSegmentLengths.push_back(0);
1810           } else {
1811             throw nb::value_error(
1812                 (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1813                  " of operation \"" + name +
1814                  "\" must be a Type (was None and result is not optional)")
1815                     .str()
1816                     .c_str());
1817           }
1818         } catch (nb::cast_error &err) {
1819           throw nb::value_error((llvm::Twine("Result ") +
1820                                  llvm::Twine(it.index()) + " of operation \"" +
1821                                  name + "\" must be a Type (" + err.what() +
1822                                  ")")
1823                                     .str()
1824                                     .c_str());
1825         }
1826       } else if (segmentSpec == -1) {
1827         // Unpack sequence by appending.
1828         try {
1829           if (std::get<0>(it.value()).is_none()) {
1830             // Treat it as an empty list.
1831             resultSegmentLengths.push_back(0);
1832           } else {
1833             // Unpack the list.
1834             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1835             for (nb::handle segmentItem : segment) {
1836               resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1837               if (!resultTypes.back()) {
1838                 throw nb::type_error("contained a None item");
1839               }
1840             }
1841             resultSegmentLengths.push_back(nb::len(segment));
1842           }
1843         } catch (std::exception &err) {
1844           // NOTE: Sloppy to be using a catch-all here, but there are at least
1845           // three different unrelated exceptions that can be thrown in the
1846           // above "casts". Just keep the scope above small and catch them all.
1847           throw nb::value_error((llvm::Twine("Result ") +
1848                                  llvm::Twine(it.index()) + " of operation \"" +
1849                                  name + "\" must be a Sequence of Types (" +
1850                                  err.what() + ")")
1851                                     .str()
1852                                     .c_str());
1853         }
1854       } else {
1855         throw nb::value_error("Unexpected segment spec");
1856       }
1857     }
1858   }
1859 }
1860 
1861 static MlirValue getUniqueResult(MlirOperation operation) {
1862   auto numResults = mlirOperationGetNumResults(operation);
1863   if (numResults != 1) {
1864     auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1865     throw nb::value_error((Twine("Cannot call .result on operation ") +
1866                            StringRef(name.data, name.length) + " which has " +
1867                            Twine(numResults) +
1868                            " results (it is only valid for operations with a "
1869                            "single result)")
1870                               .str()
1871                               .c_str());
1872   }
1873   return mlirOperationGetResult(operation, 0);
1874 }
1875 
1876 static MlirValue getOpResultOrValue(nb::handle operand) {
1877   if (operand.is_none()) {
1878     throw nb::value_error("contained a None item");
1879   }
1880   PyOperationBase *op;
1881   if (nb::try_cast<PyOperationBase *>(operand, op)) {
1882     return getUniqueResult(op->getOperation());
1883   }
1884   PyOpResultList *opResultList;
1885   if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1886     return getUniqueResult(opResultList->getOperation()->get());
1887   }
1888   PyValue *value;
1889   if (nb::try_cast<PyValue *>(operand, value)) {
1890     return value->get();
1891   }
1892   throw nb::value_error("is not a Value");
1893 }
1894 
1895 nb::object PyOpView::buildGeneric(
1896     std::string_view name, std::tuple<int, bool> opRegionSpec,
1897     nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1898     std::optional<nb::list> resultTypeList, nb::list operandList,
1899     std::optional<nb::dict> attributes,
1900     std::optional<std::vector<PyBlock *>> successors,
1901     std::optional<int> regions, DefaultingPyLocation location,
1902     const nb::object &maybeIp) {
1903   PyMlirContextRef context = location->getContext();
1904 
1905   // Class level operation construction metadata.
1906   // Operand and result segment specs are either none, which does no
1907   // variadic unpacking, or a list of ints with segment sizes, where each
1908   // element is either a positive number (typically 1 for a scalar) or -1 to
1909   // indicate that it is derived from the length of the same-indexed operand
1910   // or result (implying that it is a list at that position).
1911   std::vector<int32_t> operandSegmentLengths;
1912   std::vector<int32_t> resultSegmentLengths;
1913 
1914   // Validate/determine region count.
1915   int opMinRegionCount = std::get<0>(opRegionSpec);
1916   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1917   if (!regions) {
1918     regions = opMinRegionCount;
1919   }
1920   if (*regions < opMinRegionCount) {
1921     throw nb::value_error(
1922         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1923          llvm::Twine(opMinRegionCount) +
1924          " regions but was built with regions=" + llvm::Twine(*regions))
1925             .str()
1926             .c_str());
1927   }
1928   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1929     throw nb::value_error(
1930         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1931          llvm::Twine(opMinRegionCount) +
1932          " regions but was built with regions=" + llvm::Twine(*regions))
1933             .str()
1934             .c_str());
1935   }
1936 
1937   // Unpack results.
1938   std::vector<PyType *> resultTypes;
1939   if (resultTypeList.has_value()) {
1940     populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1941                         resultSegmentLengths, resultTypes);
1942   }
1943 
1944   // Unpack operands.
1945   llvm::SmallVector<MlirValue, 4> operands;
1946   operands.reserve(operands.size());
1947   if (operandSegmentSpecObj.is_none()) {
1948     // Non-sized operand unpacking.
1949     for (const auto &it : llvm::enumerate(operandList)) {
1950       try {
1951         operands.push_back(getOpResultOrValue(it.value()));
1952       } catch (nb::builtin_exception &err) {
1953         throw nb::value_error((llvm::Twine("Operand ") +
1954                                llvm::Twine(it.index()) + " of operation \"" +
1955                                name + "\" must be a Value (" + err.what() + ")")
1956                                   .str()
1957                                   .c_str());
1958       }
1959     }
1960   } else {
1961     // Sized operand unpacking.
1962     auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1963     if (operandSegmentSpec.size() != operandList.size()) {
1964       throw nb::value_error((llvm::Twine("Operation \"") + name +
1965                              "\" requires " +
1966                              llvm::Twine(operandSegmentSpec.size()) +
1967                              "operand segments but was provided " +
1968                              llvm::Twine(operandList.size()))
1969                                 .str()
1970                                 .c_str());
1971     }
1972     operandSegmentLengths.reserve(operandList.size());
1973     for (const auto &it :
1974          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1975       int segmentSpec = std::get<1>(it.value());
1976       if (segmentSpec == 1 || segmentSpec == 0) {
1977         // Unpack unary element.
1978         auto &operand = std::get<0>(it.value());
1979         if (!operand.is_none()) {
1980           try {
1981 
1982             operands.push_back(getOpResultOrValue(operand));
1983           } catch (nb::builtin_exception &err) {
1984             throw nb::value_error((llvm::Twine("Operand ") +
1985                                    llvm::Twine(it.index()) +
1986                                    " of operation \"" + name +
1987                                    "\" must be a Value (" + err.what() + ")")
1988                                       .str()
1989                                       .c_str());
1990           }
1991 
1992           operandSegmentLengths.push_back(1);
1993         } else if (segmentSpec == 0) {
1994           // Allowed to be optional.
1995           operandSegmentLengths.push_back(0);
1996         } else {
1997           throw nb::value_error(
1998               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
1999                " of operation \"" + name +
2000                "\" must be a Value (was None and operand is not optional)")
2001                   .str()
2002                   .c_str());
2003         }
2004       } else if (segmentSpec == -1) {
2005         // Unpack sequence by appending.
2006         try {
2007           if (std::get<0>(it.value()).is_none()) {
2008             // Treat it as an empty list.
2009             operandSegmentLengths.push_back(0);
2010           } else {
2011             // Unpack the list.
2012             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
2013             for (nb::handle segmentItem : segment) {
2014               operands.push_back(getOpResultOrValue(segmentItem));
2015             }
2016             operandSegmentLengths.push_back(nb::len(segment));
2017           }
2018         } catch (std::exception &err) {
2019           // NOTE: Sloppy to be using a catch-all here, but there are at least
2020           // three different unrelated exceptions that can be thrown in the
2021           // above "casts". Just keep the scope above small and catch them all.
2022           throw nb::value_error((llvm::Twine("Operand ") +
2023                                  llvm::Twine(it.index()) + " of operation \"" +
2024                                  name + "\" must be a Sequence of Values (" +
2025                                  err.what() + ")")
2026                                     .str()
2027                                     .c_str());
2028         }
2029       } else {
2030         throw nb::value_error("Unexpected segment spec");
2031       }
2032     }
2033   }
2034 
2035   // Merge operand/result segment lengths into attributes if needed.
2036   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
2037     // Dup.
2038     if (attributes) {
2039       attributes = nb::dict(*attributes);
2040     } else {
2041       attributes = nb::dict();
2042     }
2043     if (attributes->contains("resultSegmentSizes") ||
2044         attributes->contains("operandSegmentSizes")) {
2045       throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
2046                             "'operandSegmentSizes' attribute is unsupported. "
2047                             "Use Operation.create for such low-level access.");
2048     }
2049 
2050     // Add resultSegmentSizes attribute.
2051     if (!resultSegmentLengths.empty()) {
2052       MlirAttribute segmentLengthAttr =
2053           mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
2054                                resultSegmentLengths.data());
2055       (*attributes)["resultSegmentSizes"] =
2056           PyAttribute(context, segmentLengthAttr);
2057     }
2058 
2059     // Add operandSegmentSizes attribute.
2060     if (!operandSegmentLengths.empty()) {
2061       MlirAttribute segmentLengthAttr =
2062           mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
2063                                operandSegmentLengths.data());
2064       (*attributes)["operandSegmentSizes"] =
2065           PyAttribute(context, segmentLengthAttr);
2066     }
2067   }
2068 
2069   // Delegate to create.
2070   return PyOperation::create(name,
2071                              /*results=*/std::move(resultTypes),
2072                              /*operands=*/std::move(operands),
2073                              /*attributes=*/std::move(attributes),
2074                              /*successors=*/std::move(successors),
2075                              /*regions=*/*regions, location, maybeIp,
2076                              !resultTypeList);
2077 }
2078 
2079 nb::object PyOpView::constructDerived(const nb::object &cls,
2080                                       const nb::object &operation) {
2081   nb::handle opViewType = nb::type<PyOpView>();
2082   nb::object instance = cls.attr("__new__")(cls);
2083   opViewType.attr("__init__")(instance, operation);
2084   return instance;
2085 }
2086 
2087 PyOpView::PyOpView(const nb::object &operationObject)
2088     // Casting through the PyOperationBase base-class and then back to the
2089     // Operation lets us accept any PyOperationBase subclass.
2090     : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
2091       operationObject(operation.getRef().getObject()) {}
2092 
2093 //------------------------------------------------------------------------------
2094 // PyInsertionPoint.
2095 //------------------------------------------------------------------------------
2096 
2097 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
2098 
2099 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
2100     : refOperation(beforeOperationBase.getOperation().getRef()),
2101       block((*refOperation)->getBlock()) {}
2102 
2103 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
2104   PyOperation &operation = operationBase.getOperation();
2105   if (operation.isAttached())
2106     throw nb::value_error(
2107         "Attempt to insert operation that is already attached");
2108   block.getParentOperation()->checkValid();
2109   MlirOperation beforeOp = {nullptr};
2110   if (refOperation) {
2111     // Insert before operation.
2112     (*refOperation)->checkValid();
2113     beforeOp = (*refOperation)->get();
2114   } else {
2115     // Insert at end (before null) is only valid if the block does not
2116     // already end in a known terminator (violating this will cause assertion
2117     // failures later).
2118     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
2119       throw nb::index_error("Cannot insert operation at the end of a block "
2120                             "that already has a terminator. Did you mean to "
2121                             "use 'InsertionPoint.at_block_terminator(block)' "
2122                             "versus 'InsertionPoint(block)'?");
2123     }
2124   }
2125   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
2126   operation.setAttached();
2127 }
2128 
2129 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
2130   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
2131   if (mlirOperationIsNull(firstOp)) {
2132     // Just insert at end.
2133     return PyInsertionPoint(block);
2134   }
2135 
2136   // Insert before first op.
2137   PyOperationRef firstOpRef = PyOperation::forOperation(
2138       block.getParentOperation()->getContext(), firstOp);
2139   return PyInsertionPoint{block, std::move(firstOpRef)};
2140 }
2141 
2142 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
2143   MlirOperation terminator = mlirBlockGetTerminator(block.get());
2144   if (mlirOperationIsNull(terminator))
2145     throw nb::value_error("Block has no terminator");
2146   PyOperationRef terminatorOpRef = PyOperation::forOperation(
2147       block.getParentOperation()->getContext(), terminator);
2148   return PyInsertionPoint{block, std::move(terminatorOpRef)};
2149 }
2150 
2151 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2152   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
2153 }
2154 
2155 void PyInsertionPoint::contextExit(const nb::object &excType,
2156                                    const nb::object &excVal,
2157                                    const nb::object &excTb) {
2158   PyThreadContextEntry::popInsertionPoint(*this);
2159 }
2160 
2161 //------------------------------------------------------------------------------
2162 // PyAttribute.
2163 //------------------------------------------------------------------------------
2164 
2165 bool PyAttribute::operator==(const PyAttribute &other) const {
2166   return mlirAttributeEqual(attr, other.attr);
2167 }
2168 
2169 nb::object PyAttribute::getCapsule() {
2170   return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2171 }
2172 
2173 PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
2174   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2175   if (mlirAttributeIsNull(rawAttr))
2176     throw nb::python_error();
2177   return PyAttribute(
2178       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
2179 }
2180 
2181 //------------------------------------------------------------------------------
2182 // PyNamedAttribute.
2183 //------------------------------------------------------------------------------
2184 
2185 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2186     : ownedName(new std::string(std::move(ownedName))) {
2187   namedAttr = mlirNamedAttributeGet(
2188       mlirIdentifierGet(mlirAttributeGetContext(attr),
2189                         toMlirStringRef(*this->ownedName)),
2190       attr);
2191 }
2192 
2193 //------------------------------------------------------------------------------
2194 // PyType.
2195 //------------------------------------------------------------------------------
2196 
2197 bool PyType::operator==(const PyType &other) const {
2198   return mlirTypeEqual(type, other.type);
2199 }
2200 
2201 nb::object PyType::getCapsule() {
2202   return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2203 }
2204 
2205 PyType PyType::createFromCapsule(nb::object capsule) {
2206   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2207   if (mlirTypeIsNull(rawType))
2208     throw nb::python_error();
2209   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
2210                 rawType);
2211 }
2212 
2213 //------------------------------------------------------------------------------
2214 // PyTypeID.
2215 //------------------------------------------------------------------------------
2216 
2217 nb::object PyTypeID::getCapsule() {
2218   return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2219 }
2220 
2221 PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
2222   MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2223   if (mlirTypeIDIsNull(mlirTypeID))
2224     throw nb::python_error();
2225   return PyTypeID(mlirTypeID);
2226 }
2227 bool PyTypeID::operator==(const PyTypeID &other) const {
2228   return mlirTypeIDEqual(typeID, other.typeID);
2229 }
2230 
2231 //------------------------------------------------------------------------------
2232 // PyValue and subclasses.
2233 //------------------------------------------------------------------------------
2234 
2235 nb::object PyValue::getCapsule() {
2236   return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2237 }
2238 
2239 nb::object PyValue::maybeDownCast() {
2240   MlirType type = mlirValueGetType(get());
2241   MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2242   assert(!mlirTypeIDIsNull(mlirTypeID) &&
2243          "mlirTypeID was expected to be non-null.");
2244   std::optional<nb::callable> valueCaster =
2245       PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2246   // nb::rv_policy::move means use std::move to move the return value
2247   // contents into a new instance that will be owned by Python.
2248   nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2249   if (!valueCaster)
2250     return thisObj;
2251   return valueCaster.value()(thisObj);
2252 }
2253 
2254 PyValue PyValue::createFromCapsule(nb::object capsule) {
2255   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2256   if (mlirValueIsNull(value))
2257     throw nb::python_error();
2258   MlirOperation owner;
2259   if (mlirValueIsAOpResult(value))
2260     owner = mlirOpResultGetOwner(value);
2261   if (mlirValueIsABlockArgument(value))
2262     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
2263   if (mlirOperationIsNull(owner))
2264     throw nb::python_error();
2265   MlirContext ctx = mlirOperationGetContext(owner);
2266   PyOperationRef ownerRef =
2267       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
2268   return PyValue(ownerRef, value);
2269 }
2270 
2271 //------------------------------------------------------------------------------
2272 // PySymbolTable.
2273 //------------------------------------------------------------------------------
2274 
2275 PySymbolTable::PySymbolTable(PyOperationBase &operation)
2276     : operation(operation.getOperation().getRef()) {
2277   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2278   if (mlirSymbolTableIsNull(symbolTable)) {
2279     throw nb::type_error("Operation is not a Symbol Table.");
2280   }
2281 }
2282 
2283 nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2284   operation->checkValid();
2285   MlirOperation symbol = mlirSymbolTableLookup(
2286       symbolTable, mlirStringRefCreate(name.data(), name.length()));
2287   if (mlirOperationIsNull(symbol))
2288     throw nb::key_error(
2289         ("Symbol '" + name + "' not in the symbol table.").c_str());
2290 
2291   return PyOperation::forOperation(operation->getContext(), symbol,
2292                                    operation.getObject())
2293       ->createOpView();
2294 }
2295 
2296 void PySymbolTable::erase(PyOperationBase &symbol) {
2297   operation->checkValid();
2298   symbol.getOperation().checkValid();
2299   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2300   // The operation is also erased, so we must invalidate it. There may be Python
2301   // references to this operation so we don't want to delete it from the list of
2302   // live operations here.
2303   symbol.getOperation().valid = false;
2304 }
2305 
2306 void PySymbolTable::dunderDel(const std::string &name) {
2307   nb::object operation = dunderGetItem(name);
2308   erase(nb::cast<PyOperationBase &>(operation));
2309 }
2310 
2311 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2312   operation->checkValid();
2313   symbol.getOperation().checkValid();
2314   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2315       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
2316   if (mlirAttributeIsNull(symbolAttr))
2317     throw nb::value_error("Expected operation to have a symbol name.");
2318   return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2319 }
2320 
2321 MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2322   // Op must already be a symbol.
2323   PyOperation &operation = symbol.getOperation();
2324   operation.checkValid();
2325   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2326   MlirAttribute existingNameAttr =
2327       mlirOperationGetAttributeByName(operation.get(), attrName);
2328   if (mlirAttributeIsNull(existingNameAttr))
2329     throw nb::value_error("Expected operation to have a symbol name.");
2330   return existingNameAttr;
2331 }
2332 
2333 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2334                                   const std::string &name) {
2335   // Op must already be a symbol.
2336   PyOperation &operation = symbol.getOperation();
2337   operation.checkValid();
2338   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2339   MlirAttribute existingNameAttr =
2340       mlirOperationGetAttributeByName(operation.get(), attrName);
2341   if (mlirAttributeIsNull(existingNameAttr))
2342     throw nb::value_error("Expected operation to have a symbol name.");
2343   MlirAttribute newNameAttr =
2344       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2345   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2346 }
2347 
2348 MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2349   PyOperation &operation = symbol.getOperation();
2350   operation.checkValid();
2351   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2352   MlirAttribute existingVisAttr =
2353       mlirOperationGetAttributeByName(operation.get(), attrName);
2354   if (mlirAttributeIsNull(existingVisAttr))
2355     throw nb::value_error("Expected operation to have a symbol visibility.");
2356   return existingVisAttr;
2357 }
2358 
2359 void PySymbolTable::setVisibility(PyOperationBase &symbol,
2360                                   const std::string &visibility) {
2361   if (visibility != "public" && visibility != "private" &&
2362       visibility != "nested")
2363     throw nb::value_error(
2364         "Expected visibility to be 'public', 'private' or 'nested'");
2365   PyOperation &operation = symbol.getOperation();
2366   operation.checkValid();
2367   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2368   MlirAttribute existingVisAttr =
2369       mlirOperationGetAttributeByName(operation.get(), attrName);
2370   if (mlirAttributeIsNull(existingVisAttr))
2371     throw nb::value_error("Expected operation to have a symbol visibility.");
2372   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2373                                                toMlirStringRef(visibility));
2374   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2375 }
2376 
2377 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2378                                          const std::string &newSymbol,
2379                                          PyOperationBase &from) {
2380   PyOperation &fromOperation = from.getOperation();
2381   fromOperation.checkValid();
2382   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2383           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2384           from.getOperation())))
2385 
2386     throw nb::value_error("Symbol rename failed");
2387 }
2388 
2389 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2390                                      bool allSymUsesVisible,
2391                                      nb::object callback) {
2392   PyOperation &fromOperation = from.getOperation();
2393   fromOperation.checkValid();
2394   struct UserData {
2395     PyMlirContextRef context;
2396     nb::object callback;
2397     bool gotException;
2398     std::string exceptionWhat;
2399     nb::object exceptionType;
2400   };
2401   UserData userData{
2402       fromOperation.getContext(), std::move(callback), false, {}, {}};
2403   mlirSymbolTableWalkSymbolTables(
2404       fromOperation.get(), allSymUsesVisible,
2405       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2406         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2407         auto pyFoundOp =
2408             PyOperation::forOperation(calleeUserData->context, foundOp);
2409         if (calleeUserData->gotException)
2410           return;
2411         try {
2412           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2413         } catch (nb::python_error &e) {
2414           calleeUserData->gotException = true;
2415           calleeUserData->exceptionWhat = e.what();
2416           calleeUserData->exceptionType = nb::borrow(e.type());
2417         }
2418       },
2419       static_cast<void *>(&userData));
2420   if (userData.gotException) {
2421     std::string message("Exception raised in callback: ");
2422     message.append(userData.exceptionWhat);
2423     throw std::runtime_error(message);
2424   }
2425 }
2426 
2427 namespace {
2428 
2429 /// Python wrapper for MlirBlockArgument.
2430 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2431 public:
2432   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2433   static constexpr const char *pyClassName = "BlockArgument";
2434   using PyConcreteValue::PyConcreteValue;
2435 
2436   static void bindDerived(ClassTy &c) {
2437     c.def_prop_ro("owner", [](PyBlockArgument &self) {
2438       return PyBlock(self.getParentOperation(),
2439                      mlirBlockArgumentGetOwner(self.get()));
2440     });
2441     c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2442       return mlirBlockArgumentGetArgNumber(self.get());
2443     });
2444     c.def(
2445         "set_type",
2446         [](PyBlockArgument &self, PyType type) {
2447           return mlirBlockArgumentSetType(self.get(), type);
2448         },
2449         nb::arg("type"));
2450   }
2451 };
2452 
2453 /// A list of block arguments. Internally, these are stored as consecutive
2454 /// elements, random access is cheap. The argument list is associated with the
2455 /// operation that contains the block (detached blocks are not allowed in
2456 /// Python bindings) and extends its lifetime.
2457 class PyBlockArgumentList
2458     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2459 public:
2460   static constexpr const char *pyClassName = "BlockArgumentList";
2461   using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2462 
2463   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2464                       intptr_t startIndex = 0, intptr_t length = -1,
2465                       intptr_t step = 1)
2466       : Sliceable(startIndex,
2467                   length == -1 ? mlirBlockGetNumArguments(block) : length,
2468                   step),
2469         operation(std::move(operation)), block(block) {}
2470 
2471   static void bindDerived(ClassTy &c) {
2472     c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2473       return getValueTypes(self, self.operation->getContext());
2474     });
2475   }
2476 
2477 private:
2478   /// Give the parent CRTP class access to hook implementations below.
2479   friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2480 
2481   /// Returns the number of arguments in the list.
2482   intptr_t getRawNumElements() {
2483     operation->checkValid();
2484     return mlirBlockGetNumArguments(block);
2485   }
2486 
2487   /// Returns `pos`-the element in the list.
2488   PyBlockArgument getRawElement(intptr_t pos) {
2489     MlirValue argument = mlirBlockGetArgument(block, pos);
2490     return PyBlockArgument(operation, argument);
2491   }
2492 
2493   /// Returns a sublist of this list.
2494   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2495                             intptr_t step) {
2496     return PyBlockArgumentList(operation, block, startIndex, length, step);
2497   }
2498 
2499   PyOperationRef operation;
2500   MlirBlock block;
2501 };
2502 
2503 /// A list of operation operands. Internally, these are stored as consecutive
2504 /// elements, random access is cheap. The (returned) operand list is associated
2505 /// with the operation whose operands these are, and thus extends the lifetime
2506 /// of this operation.
2507 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2508 public:
2509   static constexpr const char *pyClassName = "OpOperandList";
2510   using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2511 
2512   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2513                   intptr_t length = -1, intptr_t step = 1)
2514       : Sliceable(startIndex,
2515                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2516                                : length,
2517                   step),
2518         operation(operation) {}
2519 
2520   void dunderSetItem(intptr_t index, PyValue value) {
2521     index = wrapIndex(index);
2522     mlirOperationSetOperand(operation->get(), index, value.get());
2523   }
2524 
2525   static void bindDerived(ClassTy &c) {
2526     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2527   }
2528 
2529 private:
2530   /// Give the parent CRTP class access to hook implementations below.
2531   friend class Sliceable<PyOpOperandList, PyValue>;
2532 
2533   intptr_t getRawNumElements() {
2534     operation->checkValid();
2535     return mlirOperationGetNumOperands(operation->get());
2536   }
2537 
2538   PyValue getRawElement(intptr_t pos) {
2539     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2540     MlirOperation owner;
2541     if (mlirValueIsAOpResult(operand))
2542       owner = mlirOpResultGetOwner(operand);
2543     else if (mlirValueIsABlockArgument(operand))
2544       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2545     else
2546       assert(false && "Value must be an block arg or op result.");
2547     PyOperationRef pyOwner =
2548         PyOperation::forOperation(operation->getContext(), owner);
2549     return PyValue(pyOwner, operand);
2550   }
2551 
2552   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2553     return PyOpOperandList(operation, startIndex, length, step);
2554   }
2555 
2556   PyOperationRef operation;
2557 };
2558 
2559 /// A list of operation successors. Internally, these are stored as consecutive
2560 /// elements, random access is cheap. The (returned) successor list is
2561 /// associated with the operation whose successors these are, and thus extends
2562 /// the lifetime of this operation.
2563 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2564 public:
2565   static constexpr const char *pyClassName = "OpSuccessors";
2566 
2567   PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2568                  intptr_t length = -1, intptr_t step = 1)
2569       : Sliceable(startIndex,
2570                   length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2571                                : length,
2572                   step),
2573         operation(operation) {}
2574 
2575   void dunderSetItem(intptr_t index, PyBlock block) {
2576     index = wrapIndex(index);
2577     mlirOperationSetSuccessor(operation->get(), index, block.get());
2578   }
2579 
2580   static void bindDerived(ClassTy &c) {
2581     c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2582   }
2583 
2584 private:
2585   /// Give the parent CRTP class access to hook implementations below.
2586   friend class Sliceable<PyOpSuccessors, PyBlock>;
2587 
2588   intptr_t getRawNumElements() {
2589     operation->checkValid();
2590     return mlirOperationGetNumSuccessors(operation->get());
2591   }
2592 
2593   PyBlock getRawElement(intptr_t pos) {
2594     MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2595     return PyBlock(operation, block);
2596   }
2597 
2598   PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2599     return PyOpSuccessors(operation, startIndex, length, step);
2600   }
2601 
2602   PyOperationRef operation;
2603 };
2604 
2605 /// A list of operation attributes. Can be indexed by name, producing
2606 /// attributes, or by index, producing named attributes.
2607 class PyOpAttributeMap {
2608 public:
2609   PyOpAttributeMap(PyOperationRef operation)
2610       : operation(std::move(operation)) {}
2611 
2612   MlirAttribute dunderGetItemNamed(const std::string &name) {
2613     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2614                                                          toMlirStringRef(name));
2615     if (mlirAttributeIsNull(attr)) {
2616       throw nb::key_error("attempt to access a non-existent attribute");
2617     }
2618     return attr;
2619   }
2620 
2621   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2622     if (index < 0 || index >= dunderLen()) {
2623       throw nb::index_error("attempt to access out of bounds attribute");
2624     }
2625     MlirNamedAttribute namedAttr =
2626         mlirOperationGetAttribute(operation->get(), index);
2627     return PyNamedAttribute(
2628         namedAttr.attribute,
2629         std::string(mlirIdentifierStr(namedAttr.name).data,
2630                     mlirIdentifierStr(namedAttr.name).length));
2631   }
2632 
2633   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2634     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2635                                     attr);
2636   }
2637 
2638   void dunderDelItem(const std::string &name) {
2639     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2640                                                      toMlirStringRef(name));
2641     if (!removed)
2642       throw nb::key_error("attempt to delete a non-existent attribute");
2643   }
2644 
2645   intptr_t dunderLen() {
2646     return mlirOperationGetNumAttributes(operation->get());
2647   }
2648 
2649   bool dunderContains(const std::string &name) {
2650     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2651         operation->get(), toMlirStringRef(name)));
2652   }
2653 
2654   static void bind(nb::module_ &m) {
2655     nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2656         .def("__contains__", &PyOpAttributeMap::dunderContains)
2657         .def("__len__", &PyOpAttributeMap::dunderLen)
2658         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2659         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2660         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2661         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2662   }
2663 
2664 private:
2665   PyOperationRef operation;
2666 };
2667 
2668 } // namespace
2669 
2670 //------------------------------------------------------------------------------
2671 // Populates the core exports of the 'ir' submodule.
2672 //------------------------------------------------------------------------------
2673 
2674 void mlir::python::populateIRCore(nb::module_ &m) {
2675   // disable leak warnings which tend to be false positives.
2676   nb::set_leak_warnings(false);
2677   //----------------------------------------------------------------------------
2678   // Enums.
2679   //----------------------------------------------------------------------------
2680   nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2681       .value("ERROR", MlirDiagnosticError)
2682       .value("WARNING", MlirDiagnosticWarning)
2683       .value("NOTE", MlirDiagnosticNote)
2684       .value("REMARK", MlirDiagnosticRemark);
2685 
2686   nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2687       .value("PRE_ORDER", MlirWalkPreOrder)
2688       .value("POST_ORDER", MlirWalkPostOrder);
2689 
2690   nb::enum_<MlirWalkResult>(m, "WalkResult")
2691       .value("ADVANCE", MlirWalkResultAdvance)
2692       .value("INTERRUPT", MlirWalkResultInterrupt)
2693       .value("SKIP", MlirWalkResultSkip);
2694 
2695   //----------------------------------------------------------------------------
2696   // Mapping of Diagnostics.
2697   //----------------------------------------------------------------------------
2698   nb::class_<PyDiagnostic>(m, "Diagnostic")
2699       .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2700       .def_prop_ro("location", &PyDiagnostic::getLocation)
2701       .def_prop_ro("message", &PyDiagnostic::getMessage)
2702       .def_prop_ro("notes", &PyDiagnostic::getNotes)
2703       .def("__str__", [](PyDiagnostic &self) -> nb::str {
2704         if (!self.isValid())
2705           return nb::str("<Invalid Diagnostic>");
2706         return self.getMessage();
2707       });
2708 
2709   nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2710       .def("__init__",
2711            [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
2712              new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2713            })
2714       .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2715       .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2716       .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2717       .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2718       .def("__str__",
2719            [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2720 
2721   nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2722       .def("detach", &PyDiagnosticHandler::detach)
2723       .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2724       .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2725       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2726       .def("__exit__", &PyDiagnosticHandler::contextExit,
2727            nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2728            nb::arg("traceback").none());
2729 
2730   //----------------------------------------------------------------------------
2731   // Mapping of MlirContext.
2732   // Note that this is exported as _BaseContext. The containing, Python level
2733   // __init__.py will subclass it with site-specific functionality and set a
2734   // "Context" attribute on this module.
2735   //----------------------------------------------------------------------------
2736   nb::class_<PyMlirContext>(m, "_BaseContext")
2737       .def("__init__",
2738            [](PyMlirContext &self) {
2739              MlirContext context = mlirContextCreateWithThreading(false);
2740              new (&self) PyMlirContext(context);
2741            })
2742       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2743       .def("_get_context_again",
2744            [](PyMlirContext &self) {
2745              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2746              return ref.releaseObject();
2747            })
2748       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2749       .def("_get_live_operation_objects",
2750            &PyMlirContext::getLiveOperationObjects)
2751       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2752       .def("_clear_live_operations_inside",
2753            nb::overload_cast<MlirOperation>(
2754                &PyMlirContext::clearOperationsInside))
2755       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2756       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2757       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2758       .def("__enter__", &PyMlirContext::contextEnter)
2759       .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2760            nb::arg("exc_value").none(), nb::arg("traceback").none())
2761       .def_prop_ro_static(
2762           "current",
2763           [](nb::object & /*class*/) {
2764             auto *context = PyThreadContextEntry::getDefaultContext();
2765             if (!context)
2766               return nb::none();
2767             return nb::cast(context);
2768           },
2769           "Gets the Context bound to the current thread or raises ValueError")
2770       .def_prop_ro(
2771           "dialects",
2772           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2773           "Gets a container for accessing dialects by name")
2774       .def_prop_ro(
2775           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2776           "Alias for 'dialect'")
2777       .def(
2778           "get_dialect_descriptor",
2779           [=](PyMlirContext &self, std::string &name) {
2780             MlirDialect dialect = mlirContextGetOrLoadDialect(
2781                 self.get(), {name.data(), name.size()});
2782             if (mlirDialectIsNull(dialect)) {
2783               throw nb::value_error(
2784                   (Twine("Dialect '") + name + "' not found").str().c_str());
2785             }
2786             return PyDialectDescriptor(self.getRef(), dialect);
2787           },
2788           nb::arg("dialect_name"),
2789           "Gets or loads a dialect by name, returning its descriptor object")
2790       .def_prop_rw(
2791           "allow_unregistered_dialects",
2792           [](PyMlirContext &self) -> bool {
2793             return mlirContextGetAllowUnregisteredDialects(self.get());
2794           },
2795           [](PyMlirContext &self, bool value) {
2796             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2797           })
2798       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2799            nb::arg("callback"),
2800            "Attaches a diagnostic handler that will receive callbacks")
2801       .def(
2802           "enable_multithreading",
2803           [](PyMlirContext &self, bool enable) {
2804             mlirContextEnableMultithreading(self.get(), enable);
2805           },
2806           nb::arg("enable"))
2807       .def(
2808           "is_registered_operation",
2809           [](PyMlirContext &self, std::string &name) {
2810             return mlirContextIsRegisteredOperation(
2811                 self.get(), MlirStringRef{name.data(), name.size()});
2812           },
2813           nb::arg("operation_name"))
2814       .def(
2815           "append_dialect_registry",
2816           [](PyMlirContext &self, PyDialectRegistry &registry) {
2817             mlirContextAppendDialectRegistry(self.get(), registry);
2818           },
2819           nb::arg("registry"))
2820       .def_prop_rw("emit_error_diagnostics", nullptr,
2821                    &PyMlirContext::setEmitErrorDiagnostics,
2822                    "Emit error diagnostics to diagnostic handlers. By default "
2823                    "error diagnostics are captured and reported through "
2824                    "MLIRError exceptions.")
2825       .def("load_all_available_dialects", [](PyMlirContext &self) {
2826         mlirContextLoadAllAvailableDialects(self.get());
2827       });
2828 
2829   //----------------------------------------------------------------------------
2830   // Mapping of PyDialectDescriptor
2831   //----------------------------------------------------------------------------
2832   nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2833       .def_prop_ro("namespace",
2834                    [](PyDialectDescriptor &self) {
2835                      MlirStringRef ns = mlirDialectGetNamespace(self.get());
2836                      return nb::str(ns.data, ns.length);
2837                    })
2838       .def("__repr__", [](PyDialectDescriptor &self) {
2839         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2840         std::string repr("<DialectDescriptor ");
2841         repr.append(ns.data, ns.length);
2842         repr.append(">");
2843         return repr;
2844       });
2845 
2846   //----------------------------------------------------------------------------
2847   // Mapping of PyDialects
2848   //----------------------------------------------------------------------------
2849   nb::class_<PyDialects>(m, "Dialects")
2850       .def("__getitem__",
2851            [=](PyDialects &self, std::string keyName) {
2852              MlirDialect dialect =
2853                  self.getDialectForKey(keyName, /*attrError=*/false);
2854              nb::object descriptor =
2855                  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2856              return createCustomDialectWrapper(keyName, std::move(descriptor));
2857            })
2858       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2859         MlirDialect dialect =
2860             self.getDialectForKey(attrName, /*attrError=*/true);
2861         nb::object descriptor =
2862             nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2863         return createCustomDialectWrapper(attrName, std::move(descriptor));
2864       });
2865 
2866   //----------------------------------------------------------------------------
2867   // Mapping of PyDialect
2868   //----------------------------------------------------------------------------
2869   nb::class_<PyDialect>(m, "Dialect")
2870       .def(nb::init<nb::object>(), nb::arg("descriptor"))
2871       .def_prop_ro("descriptor",
2872                    [](PyDialect &self) { return self.getDescriptor(); })
2873       .def("__repr__", [](nb::object self) {
2874         auto clazz = self.attr("__class__");
2875         return nb::str("<Dialect ") +
2876                self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
2877                clazz.attr("__module__") + nb::str(".") +
2878                clazz.attr("__name__") + nb::str(")>");
2879       });
2880 
2881   //----------------------------------------------------------------------------
2882   // Mapping of PyDialectRegistry
2883   //----------------------------------------------------------------------------
2884   nb::class_<PyDialectRegistry>(m, "DialectRegistry")
2885       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
2886       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2887       .def(nb::init<>());
2888 
2889   //----------------------------------------------------------------------------
2890   // Mapping of Location
2891   //----------------------------------------------------------------------------
2892   nb::class_<PyLocation>(m, "Location")
2893       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2894       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2895       .def("__enter__", &PyLocation::contextEnter)
2896       .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
2897            nb::arg("exc_value").none(), nb::arg("traceback").none())
2898       .def("__eq__",
2899            [](PyLocation &self, PyLocation &other) -> bool {
2900              return mlirLocationEqual(self, other);
2901            })
2902       .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
2903       .def_prop_ro_static(
2904           "current",
2905           [](nb::object & /*class*/) {
2906             auto *loc = PyThreadContextEntry::getDefaultLocation();
2907             if (!loc)
2908               throw nb::value_error("No current Location");
2909             return loc;
2910           },
2911           "Gets the Location bound to the current thread or raises ValueError")
2912       .def_static(
2913           "unknown",
2914           [](DefaultingPyMlirContext context) {
2915             return PyLocation(context->getRef(),
2916                               mlirLocationUnknownGet(context->get()));
2917           },
2918           nb::arg("context").none() = nb::none(),
2919           "Gets a Location representing an unknown location")
2920       .def_static(
2921           "callsite",
2922           [](PyLocation callee, const std::vector<PyLocation> &frames,
2923              DefaultingPyMlirContext context) {
2924             if (frames.empty())
2925               throw nb::value_error("No caller frames provided");
2926             MlirLocation caller = frames.back().get();
2927             for (const PyLocation &frame :
2928                  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2929               caller = mlirLocationCallSiteGet(frame.get(), caller);
2930             return PyLocation(context->getRef(),
2931                               mlirLocationCallSiteGet(callee.get(), caller));
2932           },
2933           nb::arg("callee"), nb::arg("frames"),
2934           nb::arg("context").none() = nb::none(),
2935           kContextGetCallSiteLocationDocstring)
2936       .def_static(
2937           "file",
2938           [](std::string filename, int line, int col,
2939              DefaultingPyMlirContext context) {
2940             return PyLocation(
2941                 context->getRef(),
2942                 mlirLocationFileLineColGet(
2943                     context->get(), toMlirStringRef(filename), line, col));
2944           },
2945           nb::arg("filename"), nb::arg("line"), nb::arg("col"),
2946           nb::arg("context").none() = nb::none(),
2947           kContextGetFileLocationDocstring)
2948       .def_static(
2949           "file",
2950           [](std::string filename, int startLine, int startCol, int endLine,
2951              int endCol, DefaultingPyMlirContext context) {
2952             return PyLocation(context->getRef(),
2953                               mlirLocationFileLineColRangeGet(
2954                                   context->get(), toMlirStringRef(filename),
2955                                   startLine, startCol, endLine, endCol));
2956           },
2957           nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
2958           nb::arg("end_line"), nb::arg("end_col"),
2959           nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
2960       .def_static(
2961           "fused",
2962           [](const std::vector<PyLocation> &pyLocations,
2963              std::optional<PyAttribute> metadata,
2964              DefaultingPyMlirContext context) {
2965             llvm::SmallVector<MlirLocation, 4> locations;
2966             locations.reserve(pyLocations.size());
2967             for (auto &pyLocation : pyLocations)
2968               locations.push_back(pyLocation.get());
2969             MlirLocation location = mlirLocationFusedGet(
2970                 context->get(), locations.size(), locations.data(),
2971                 metadata ? metadata->get() : MlirAttribute{0});
2972             return PyLocation(context->getRef(), location);
2973           },
2974           nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
2975           nb::arg("context").none() = nb::none(),
2976           kContextGetFusedLocationDocstring)
2977       .def_static(
2978           "name",
2979           [](std::string name, std::optional<PyLocation> childLoc,
2980              DefaultingPyMlirContext context) {
2981             return PyLocation(
2982                 context->getRef(),
2983                 mlirLocationNameGet(
2984                     context->get(), toMlirStringRef(name),
2985                     childLoc ? childLoc->get()
2986                              : mlirLocationUnknownGet(context->get())));
2987           },
2988           nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
2989           nb::arg("context").none() = nb::none(),
2990           kContextGetNameLocationDocString)
2991       .def_static(
2992           "from_attr",
2993           [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2994             return PyLocation(context->getRef(),
2995                               mlirLocationFromAttribute(attribute));
2996           },
2997           nb::arg("attribute"), nb::arg("context").none() = nb::none(),
2998           "Gets a Location from a LocationAttr")
2999       .def_prop_ro(
3000           "context",
3001           [](PyLocation &self) { return self.getContext().getObject(); },
3002           "Context that owns the Location")
3003       .def_prop_ro(
3004           "attr",
3005           [](PyLocation &self) { return mlirLocationGetAttribute(self); },
3006           "Get the underlying LocationAttr")
3007       .def(
3008           "emit_error",
3009           [](PyLocation &self, std::string message) {
3010             mlirEmitError(self, message.c_str());
3011           },
3012           nb::arg("message"), "Emits an error at this location")
3013       .def("__repr__", [](PyLocation &self) {
3014         PyPrintAccumulator printAccum;
3015         mlirLocationPrint(self, printAccum.getCallback(),
3016                           printAccum.getUserData());
3017         return printAccum.join();
3018       });
3019 
3020   //----------------------------------------------------------------------------
3021   // Mapping of Module
3022   //----------------------------------------------------------------------------
3023   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3024       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3025       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3026       .def_static(
3027           "parse",
3028           [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
3029             PyMlirContext::ErrorCapture errors(context->getRef());
3030             MlirModule module = mlirModuleCreateParse(
3031                 context->get(), toMlirStringRef(moduleAsm));
3032             if (mlirModuleIsNull(module))
3033               throw MLIRError("Unable to parse module assembly", errors.take());
3034             return PyModule::forModule(module).releaseObject();
3035           },
3036           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3037           kModuleParseDocstring)
3038       .def_static(
3039           "parse",
3040           [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
3041             PyMlirContext::ErrorCapture errors(context->getRef());
3042             MlirModule module = mlirModuleCreateParse(
3043                 context->get(), toMlirStringRef(moduleAsm));
3044             if (mlirModuleIsNull(module))
3045               throw MLIRError("Unable to parse module assembly", errors.take());
3046             return PyModule::forModule(module).releaseObject();
3047           },
3048           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3049           kModuleParseDocstring)
3050       .def_static(
3051           "create",
3052           [](DefaultingPyLocation loc) {
3053             MlirModule module = mlirModuleCreateEmpty(loc);
3054             return PyModule::forModule(module).releaseObject();
3055           },
3056           nb::arg("loc").none() = nb::none(), "Creates an empty module")
3057       .def_prop_ro(
3058           "context",
3059           [](PyModule &self) { return self.getContext().getObject(); },
3060           "Context that created the Module")
3061       .def_prop_ro(
3062           "operation",
3063           [](PyModule &self) {
3064             return PyOperation::forOperation(self.getContext(),
3065                                              mlirModuleGetOperation(self.get()),
3066                                              self.getRef().releaseObject())
3067                 .releaseObject();
3068           },
3069           "Accesses the module as an operation")
3070       .def_prop_ro(
3071           "body",
3072           [](PyModule &self) {
3073             PyOperationRef moduleOp = PyOperation::forOperation(
3074                 self.getContext(), mlirModuleGetOperation(self.get()),
3075                 self.getRef().releaseObject());
3076             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3077             return returnBlock;
3078           },
3079           "Return the block for this module")
3080       .def(
3081           "dump",
3082           [](PyModule &self) {
3083             mlirOperationDump(mlirModuleGetOperation(self.get()));
3084           },
3085           kDumpDocstring)
3086       .def(
3087           "__str__",
3088           [](nb::object self) {
3089             // Defer to the operation's __str__.
3090             return self.attr("operation").attr("__str__")();
3091           },
3092           kOperationStrDunderDocstring);
3093 
3094   //----------------------------------------------------------------------------
3095   // Mapping of Operation.
3096   //----------------------------------------------------------------------------
3097   nb::class_<PyOperationBase>(m, "_OperationBase")
3098       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3099                    [](PyOperationBase &self) {
3100                      return self.getOperation().getCapsule();
3101                    })
3102       .def("__eq__",
3103            [](PyOperationBase &self, PyOperationBase &other) {
3104              return &self.getOperation() == &other.getOperation();
3105            })
3106       .def("__eq__",
3107            [](PyOperationBase &self, nb::object other) { return false; })
3108       .def("__hash__",
3109            [](PyOperationBase &self) {
3110              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3111            })
3112       .def_prop_ro("attributes",
3113                    [](PyOperationBase &self) {
3114                      return PyOpAttributeMap(self.getOperation().getRef());
3115                    })
3116       .def_prop_ro(
3117           "context",
3118           [](PyOperationBase &self) {
3119             PyOperation &concreteOperation = self.getOperation();
3120             concreteOperation.checkValid();
3121             return concreteOperation.getContext().getObject();
3122           },
3123           "Context that owns the Operation")
3124       .def_prop_ro("name",
3125                    [](PyOperationBase &self) {
3126                      auto &concreteOperation = self.getOperation();
3127                      concreteOperation.checkValid();
3128                      MlirOperation operation = concreteOperation.get();
3129                      MlirStringRef name =
3130                          mlirIdentifierStr(mlirOperationGetName(operation));
3131                      return nb::str(name.data, name.length);
3132                    })
3133       .def_prop_ro("operands",
3134                    [](PyOperationBase &self) {
3135                      return PyOpOperandList(self.getOperation().getRef());
3136                    })
3137       .def_prop_ro("regions",
3138                    [](PyOperationBase &self) {
3139                      return PyRegionList(self.getOperation().getRef());
3140                    })
3141       .def_prop_ro(
3142           "results",
3143           [](PyOperationBase &self) {
3144             return PyOpResultList(self.getOperation().getRef());
3145           },
3146           "Returns the list of Operation results.")
3147       .def_prop_ro(
3148           "result",
3149           [](PyOperationBase &self) {
3150             auto &operation = self.getOperation();
3151             return PyOpResult(operation.getRef(), getUniqueResult(operation))
3152                 .maybeDownCast();
3153           },
3154           "Shortcut to get an op result if it has only one (throws an error "
3155           "otherwise).")
3156       .def_prop_ro(
3157           "location",
3158           [](PyOperationBase &self) {
3159             PyOperation &operation = self.getOperation();
3160             return PyLocation(operation.getContext(),
3161                               mlirOperationGetLocation(operation.get()));
3162           },
3163           "Returns the source location the operation was defined or derived "
3164           "from.")
3165       .def_prop_ro("parent",
3166                    [](PyOperationBase &self) -> nb::object {
3167                      auto parent = self.getOperation().getParentOperation();
3168                      if (parent)
3169                        return parent->getObject();
3170                      return nb::none();
3171                    })
3172       .def(
3173           "__str__",
3174           [](PyOperationBase &self) {
3175             return self.getAsm(/*binary=*/false,
3176                                /*largeElementsLimit=*/std::nullopt,
3177                                /*enableDebugInfo=*/false,
3178                                /*prettyDebugInfo=*/false,
3179                                /*printGenericOpForm=*/false,
3180                                /*useLocalScope=*/false,
3181                                /*assumeVerified=*/false,
3182                                /*skipRegions=*/false);
3183           },
3184           "Returns the assembly form of the operation.")
3185       .def("print",
3186            nb::overload_cast<PyAsmState &, nb::object, bool>(
3187                &PyOperationBase::print),
3188            nb::arg("state"), nb::arg("file").none() = nb::none(),
3189            nb::arg("binary") = false, kOperationPrintStateDocstring)
3190       .def("print",
3191            nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3192                              bool, nb::object, bool, bool>(
3193                &PyOperationBase::print),
3194            // Careful: Lots of arguments must match up with print method.
3195            nb::arg("large_elements_limit").none() = nb::none(),
3196            nb::arg("enable_debug_info") = false,
3197            nb::arg("pretty_debug_info") = false,
3198            nb::arg("print_generic_op_form") = false,
3199            nb::arg("use_local_scope") = false,
3200            nb::arg("assume_verified") = false,
3201            nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3202            nb::arg("skip_regions") = false, kOperationPrintDocstring)
3203       .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3204            nb::arg("desired_version").none() = nb::none(),
3205            kOperationPrintBytecodeDocstring)
3206       .def("get_asm", &PyOperationBase::getAsm,
3207            // Careful: Lots of arguments must match up with get_asm method.
3208            nb::arg("binary") = false,
3209            nb::arg("large_elements_limit").none() = nb::none(),
3210            nb::arg("enable_debug_info") = false,
3211            nb::arg("pretty_debug_info") = false,
3212            nb::arg("print_generic_op_form") = false,
3213            nb::arg("use_local_scope") = false,
3214            nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3215            kOperationGetAsmDocstring)
3216       .def("verify", &PyOperationBase::verify,
3217            "Verify the operation. Raises MLIRError if verification fails, and "
3218            "returns true otherwise.")
3219       .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3220            "Puts self immediately after the other operation in its parent "
3221            "block.")
3222       .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3223            "Puts self immediately before the other operation in its parent "
3224            "block.")
3225       .def(
3226           "clone",
3227           [](PyOperationBase &self, nb::object ip) {
3228             return self.getOperation().clone(ip);
3229           },
3230           nb::arg("ip").none() = nb::none())
3231       .def(
3232           "detach_from_parent",
3233           [](PyOperationBase &self) {
3234             PyOperation &operation = self.getOperation();
3235             operation.checkValid();
3236             if (!operation.isAttached())
3237               throw nb::value_error("Detached operation has no parent.");
3238 
3239             operation.detachFromParent();
3240             return operation.createOpView();
3241           },
3242           "Detaches the operation from its parent block.")
3243       .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3244       .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3245            nb::arg("walk_order") = MlirWalkPostOrder);
3246 
3247   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3248       .def_static(
3249           "create",
3250           [](std::string_view name,
3251              std::optional<std::vector<PyType *>> results,
3252              std::optional<std::vector<PyValue *>> operands,
3253              std::optional<nb::dict> attributes,
3254              std::optional<std::vector<PyBlock *>> successors, int regions,
3255              DefaultingPyLocation location, const nb::object &maybeIp,
3256              bool inferType) {
3257             // Unpack/validate operands.
3258             llvm::SmallVector<MlirValue, 4> mlirOperands;
3259             if (operands) {
3260               mlirOperands.reserve(operands->size());
3261               for (PyValue *operand : *operands) {
3262                 if (!operand)
3263                   throw nb::value_error("operand value cannot be None");
3264                 mlirOperands.push_back(operand->get());
3265               }
3266             }
3267 
3268             return PyOperation::create(name, results, mlirOperands, attributes,
3269                                        successors, regions, location, maybeIp,
3270                                        inferType);
3271           },
3272           nb::arg("name"), nb::arg("results").none() = nb::none(),
3273           nb::arg("operands").none() = nb::none(),
3274           nb::arg("attributes").none() = nb::none(),
3275           nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
3276           nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3277           nb::arg("infer_type") = false, kOperationCreateDocstring)
3278       .def_static(
3279           "parse",
3280           [](const std::string &sourceStr, const std::string &sourceName,
3281              DefaultingPyMlirContext context) {
3282             return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3283                 ->createOpView();
3284           },
3285           nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3286           nb::arg("context").none() = nb::none(),
3287           "Parses an operation. Supports both text assembly format and binary "
3288           "bytecode format.")
3289       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3290       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3291       .def_prop_ro("operation", [](nb::object self) { return self; })
3292       .def_prop_ro("opview", &PyOperation::createOpView)
3293       .def_prop_ro(
3294           "successors",
3295           [](PyOperationBase &self) {
3296             return PyOpSuccessors(self.getOperation().getRef());
3297           },
3298           "Returns the list of Operation successors.");
3299 
3300   auto opViewClass =
3301       nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3302           .def(nb::init<nb::object>(), nb::arg("operation"))
3303           .def(
3304               "__init__",
3305               [](PyOpView *self, std::string_view name,
3306                  std::tuple<int, bool> opRegionSpec,
3307                  nb::object operandSegmentSpecObj,
3308                  nb::object resultSegmentSpecObj,
3309                  std::optional<nb::list> resultTypeList, nb::list operandList,
3310                  std::optional<nb::dict> attributes,
3311                  std::optional<std::vector<PyBlock *>> successors,
3312                  std::optional<int> regions, DefaultingPyLocation location,
3313                  const nb::object &maybeIp) {
3314                 new (self) PyOpView(PyOpView::buildGeneric(
3315                     name, opRegionSpec, operandSegmentSpecObj,
3316                     resultSegmentSpecObj, resultTypeList, operandList,
3317                     attributes, successors, regions, location, maybeIp));
3318               },
3319               nb::arg("name"), nb::arg("opRegionSpec"),
3320               nb::arg("operandSegmentSpecObj").none() = nb::none(),
3321               nb::arg("resultSegmentSpecObj").none() = nb::none(),
3322               nb::arg("results").none() = nb::none(),
3323               nb::arg("operands").none() = nb::none(),
3324               nb::arg("attributes").none() = nb::none(),
3325               nb::arg("successors").none() = nb::none(),
3326               nb::arg("regions").none() = nb::none(),
3327               nb::arg("loc").none() = nb::none(),
3328               nb::arg("ip").none() = nb::none())
3329 
3330           .def_prop_ro("operation", &PyOpView::getOperationObject)
3331           .def_prop_ro("opview", [](nb::object self) { return self; })
3332           .def(
3333               "__str__",
3334               [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3335           .def_prop_ro(
3336               "successors",
3337               [](PyOperationBase &self) {
3338                 return PyOpSuccessors(self.getOperation().getRef());
3339               },
3340               "Returns the list of Operation successors.");
3341   opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3342   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3343   opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3344   // It is faster to pass the operation_name, ods_regions, and
3345   // ods_operand_segments/ods_result_segments as arguments to the constructor,
3346   // rather than to access them as attributes.
3347   opViewClass.attr("build_generic") = classmethod(
3348       [](nb::handle cls, std::optional<nb::list> resultTypeList,
3349          nb::list operandList, std::optional<nb::dict> attributes,
3350          std::optional<std::vector<PyBlock *>> successors,
3351          std::optional<int> regions, DefaultingPyLocation location,
3352          const nb::object &maybeIp) {
3353         std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3354         std::tuple<int, bool> opRegionSpec =
3355             nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3356         nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3357         nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3358         return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3359                                       resultSegmentSpec, resultTypeList,
3360                                       operandList, attributes, successors,
3361                                       regions, location, maybeIp);
3362       },
3363       nb::arg("cls"), nb::arg("results").none() = nb::none(),
3364       nb::arg("operands").none() = nb::none(),
3365       nb::arg("attributes").none() = nb::none(),
3366       nb::arg("successors").none() = nb::none(),
3367       nb::arg("regions").none() = nb::none(),
3368       nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3369       "Builds a specific, generated OpView based on class level attributes.");
3370   opViewClass.attr("parse") = classmethod(
3371       [](const nb::object &cls, const std::string &sourceStr,
3372          const std::string &sourceName, DefaultingPyMlirContext context) {
3373         PyOperationRef parsed =
3374             PyOperation::parse(context->getRef(), sourceStr, sourceName);
3375 
3376         // Check if the expected operation was parsed, and cast to to the
3377         // appropriate `OpView` subclass if successful.
3378         // NOTE: This accesses attributes that have been automatically added to
3379         // `OpView` subclasses, and is not intended to be used on `OpView`
3380         // directly.
3381         std::string clsOpName =
3382             nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3383         MlirStringRef identifier =
3384             mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
3385         std::string_view parsedOpName(identifier.data, identifier.length);
3386         if (clsOpName != parsedOpName)
3387           throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3388                           parsedOpName + "'");
3389         return PyOpView::constructDerived(cls, parsed.getObject());
3390       },
3391       nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3392       nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
3393       "Parses a specific, generated OpView based on class level attributes");
3394 
3395   //----------------------------------------------------------------------------
3396   // Mapping of PyRegion.
3397   //----------------------------------------------------------------------------
3398   nb::class_<PyRegion>(m, "Region")
3399       .def_prop_ro(
3400           "blocks",
3401           [](PyRegion &self) {
3402             return PyBlockList(self.getParentOperation(), self.get());
3403           },
3404           "Returns a forward-optimized sequence of blocks.")
3405       .def_prop_ro(
3406           "owner",
3407           [](PyRegion &self) {
3408             return self.getParentOperation()->createOpView();
3409           },
3410           "Returns the operation owning this region.")
3411       .def(
3412           "__iter__",
3413           [](PyRegion &self) {
3414             self.checkValid();
3415             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3416             return PyBlockIterator(self.getParentOperation(), firstBlock);
3417           },
3418           "Iterates over blocks in the region.")
3419       .def("__eq__",
3420            [](PyRegion &self, PyRegion &other) {
3421              return self.get().ptr == other.get().ptr;
3422            })
3423       .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3424 
3425   //----------------------------------------------------------------------------
3426   // Mapping of PyBlock.
3427   //----------------------------------------------------------------------------
3428   nb::class_<PyBlock>(m, "Block")
3429       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3430       .def_prop_ro(
3431           "owner",
3432           [](PyBlock &self) {
3433             return self.getParentOperation()->createOpView();
3434           },
3435           "Returns the owning operation of this block.")
3436       .def_prop_ro(
3437           "region",
3438           [](PyBlock &self) {
3439             MlirRegion region = mlirBlockGetParentRegion(self.get());
3440             return PyRegion(self.getParentOperation(), region);
3441           },
3442           "Returns the owning region of this block.")
3443       .def_prop_ro(
3444           "arguments",
3445           [](PyBlock &self) {
3446             return PyBlockArgumentList(self.getParentOperation(), self.get());
3447           },
3448           "Returns a list of block arguments.")
3449       .def(
3450           "add_argument",
3451           [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3452             return mlirBlockAddArgument(self.get(), type, loc);
3453           },
3454           "Append an argument of the specified type to the block and returns "
3455           "the newly added argument.")
3456       .def(
3457           "erase_argument",
3458           [](PyBlock &self, unsigned index) {
3459             return mlirBlockEraseArgument(self.get(), index);
3460           },
3461           "Erase the argument at 'index' and remove it from the argument list.")
3462       .def_prop_ro(
3463           "operations",
3464           [](PyBlock &self) {
3465             return PyOperationList(self.getParentOperation(), self.get());
3466           },
3467           "Returns a forward-optimized sequence of operations.")
3468       .def_static(
3469           "create_at_start",
3470           [](PyRegion &parent, const nb::sequence &pyArgTypes,
3471              const std::optional<nb::sequence> &pyArgLocs) {
3472             parent.checkValid();
3473             MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3474             mlirRegionInsertOwnedBlock(parent, 0, block);
3475             return PyBlock(parent.getParentOperation(), block);
3476           },
3477           nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3478           nb::arg("arg_locs") = std::nullopt,
3479           "Creates and returns a new Block at the beginning of the given "
3480           "region (with given argument types and locations).")
3481       .def(
3482           "append_to",
3483           [](PyBlock &self, PyRegion &region) {
3484             MlirBlock b = self.get();
3485             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
3486               mlirBlockDetach(b);
3487             mlirRegionAppendOwnedBlock(region.get(), b);
3488           },
3489           "Append this block to a region, transferring ownership if necessary")
3490       .def(
3491           "create_before",
3492           [](PyBlock &self, const nb::args &pyArgTypes,
3493              const std::optional<nb::sequence> &pyArgLocs) {
3494             self.checkValid();
3495             MlirBlock block =
3496                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3497             MlirRegion region = mlirBlockGetParentRegion(self.get());
3498             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3499             return PyBlock(self.getParentOperation(), block);
3500           },
3501           nb::arg("arg_types"), nb::kw_only(),
3502           nb::arg("arg_locs") = std::nullopt,
3503           "Creates and returns a new Block before this block "
3504           "(with given argument types and locations).")
3505       .def(
3506           "create_after",
3507           [](PyBlock &self, const nb::args &pyArgTypes,
3508              const std::optional<nb::sequence> &pyArgLocs) {
3509             self.checkValid();
3510             MlirBlock block =
3511                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3512             MlirRegion region = mlirBlockGetParentRegion(self.get());
3513             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3514             return PyBlock(self.getParentOperation(), block);
3515           },
3516           nb::arg("arg_types"), nb::kw_only(),
3517           nb::arg("arg_locs") = std::nullopt,
3518           "Creates and returns a new Block after this block "
3519           "(with given argument types and locations).")
3520       .def(
3521           "__iter__",
3522           [](PyBlock &self) {
3523             self.checkValid();
3524             MlirOperation firstOperation =
3525                 mlirBlockGetFirstOperation(self.get());
3526             return PyOperationIterator(self.getParentOperation(),
3527                                        firstOperation);
3528           },
3529           "Iterates over operations in the block.")
3530       .def("__eq__",
3531            [](PyBlock &self, PyBlock &other) {
3532              return self.get().ptr == other.get().ptr;
3533            })
3534       .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3535       .def("__hash__",
3536            [](PyBlock &self) {
3537              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3538            })
3539       .def(
3540           "__str__",
3541           [](PyBlock &self) {
3542             self.checkValid();
3543             PyPrintAccumulator printAccum;
3544             mlirBlockPrint(self.get(), printAccum.getCallback(),
3545                            printAccum.getUserData());
3546             return printAccum.join();
3547           },
3548           "Returns the assembly form of the block.")
3549       .def(
3550           "append",
3551           [](PyBlock &self, PyOperationBase &operation) {
3552             if (operation.getOperation().isAttached())
3553               operation.getOperation().detachFromParent();
3554 
3555             MlirOperation mlirOperation = operation.getOperation().get();
3556             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3557             operation.getOperation().setAttached(
3558                 self.getParentOperation().getObject());
3559           },
3560           nb::arg("operation"),
3561           "Appends an operation to this block. If the operation is currently "
3562           "in another block, it will be moved.");
3563 
3564   //----------------------------------------------------------------------------
3565   // Mapping of PyInsertionPoint.
3566   //----------------------------------------------------------------------------
3567 
3568   nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3569       .def(nb::init<PyBlock &>(), nb::arg("block"),
3570            "Inserts after the last operation but still inside the block.")
3571       .def("__enter__", &PyInsertionPoint::contextEnter)
3572       .def("__exit__", &PyInsertionPoint::contextExit,
3573            nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3574            nb::arg("traceback").none())
3575       .def_prop_ro_static(
3576           "current",
3577           [](nb::object & /*class*/) {
3578             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3579             if (!ip)
3580               throw nb::value_error("No current InsertionPoint");
3581             return ip;
3582           },
3583           "Gets the InsertionPoint bound to the current thread or raises "
3584           "ValueError if none has been set")
3585       .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3586            "Inserts before a referenced operation.")
3587       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3588                   nb::arg("block"), "Inserts at the beginning of the block.")
3589       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3590                   nb::arg("block"), "Inserts before the block terminator.")
3591       .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3592            "Inserts an operation.")
3593       .def_prop_ro(
3594           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3595           "Returns the block that this InsertionPoint points to.")
3596       .def_prop_ro(
3597           "ref_operation",
3598           [](PyInsertionPoint &self) -> nb::object {
3599             auto refOperation = self.getRefOperation();
3600             if (refOperation)
3601               return refOperation->getObject();
3602             return nb::none();
3603           },
3604           "The reference operation before which new operations are "
3605           "inserted, or None if the insertion point is at the end of "
3606           "the block");
3607 
3608   //----------------------------------------------------------------------------
3609   // Mapping of PyAttribute.
3610   //----------------------------------------------------------------------------
3611   nb::class_<PyAttribute>(m, "Attribute")
3612       // Delegate to the PyAttribute copy constructor, which will also lifetime
3613       // extend the backing context which owns the MlirAttribute.
3614       .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3615            "Casts the passed attribute to the generic Attribute")
3616       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3617       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3618       .def_static(
3619           "parse",
3620           [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3621             PyMlirContext::ErrorCapture errors(context->getRef());
3622             MlirAttribute attr = mlirAttributeParseGet(
3623                 context->get(), toMlirStringRef(attrSpec));
3624             if (mlirAttributeIsNull(attr))
3625               throw MLIRError("Unable to parse attribute", errors.take());
3626             return attr;
3627           },
3628           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3629           "Parses an attribute from an assembly form. Raises an MLIRError on "
3630           "failure.")
3631       .def_prop_ro(
3632           "context",
3633           [](PyAttribute &self) { return self.getContext().getObject(); },
3634           "Context that owns the Attribute")
3635       .def_prop_ro("type",
3636                    [](PyAttribute &self) { return mlirAttributeGetType(self); })
3637       .def(
3638           "get_named",
3639           [](PyAttribute &self, std::string name) {
3640             return PyNamedAttribute(self, std::move(name));
3641           },
3642           nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3643       .def("__eq__",
3644            [](PyAttribute &self, PyAttribute &other) { return self == other; })
3645       .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3646       .def("__hash__",
3647            [](PyAttribute &self) {
3648              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3649            })
3650       .def(
3651           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3652           kDumpDocstring)
3653       .def(
3654           "__str__",
3655           [](PyAttribute &self) {
3656             PyPrintAccumulator printAccum;
3657             mlirAttributePrint(self, printAccum.getCallback(),
3658                                printAccum.getUserData());
3659             return printAccum.join();
3660           },
3661           "Returns the assembly form of the Attribute.")
3662       .def("__repr__",
3663            [](PyAttribute &self) {
3664              // Generally, assembly formats are not printed for __repr__ because
3665              // this can cause exceptionally long debug output and exceptions.
3666              // However, attribute values are generally considered useful and
3667              // are printed. This may need to be re-evaluated if debug dumps end
3668              // up being excessive.
3669              PyPrintAccumulator printAccum;
3670              printAccum.parts.append("Attribute(");
3671              mlirAttributePrint(self, printAccum.getCallback(),
3672                                 printAccum.getUserData());
3673              printAccum.parts.append(")");
3674              return printAccum.join();
3675            })
3676       .def_prop_ro("typeid",
3677                    [](PyAttribute &self) -> MlirTypeID {
3678                      MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3679                      assert(!mlirTypeIDIsNull(mlirTypeID) &&
3680                             "mlirTypeID was expected to be non-null.");
3681                      return mlirTypeID;
3682                    })
3683       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
3684         MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3685         assert(!mlirTypeIDIsNull(mlirTypeID) &&
3686                "mlirTypeID was expected to be non-null.");
3687         std::optional<nb::callable> typeCaster =
3688             PyGlobals::get().lookupTypeCaster(mlirTypeID,
3689                                               mlirAttributeGetDialect(self));
3690         if (!typeCaster)
3691           return nb::cast(self);
3692         return typeCaster.value()(self);
3693       });
3694 
3695   //----------------------------------------------------------------------------
3696   // Mapping of PyNamedAttribute
3697   //----------------------------------------------------------------------------
3698   nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3699       .def("__repr__",
3700            [](PyNamedAttribute &self) {
3701              PyPrintAccumulator printAccum;
3702              printAccum.parts.append("NamedAttribute(");
3703              printAccum.parts.append(
3704                  nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3705                          mlirIdentifierStr(self.namedAttr.name).length));
3706              printAccum.parts.append("=");
3707              mlirAttributePrint(self.namedAttr.attribute,
3708                                 printAccum.getCallback(),
3709                                 printAccum.getUserData());
3710              printAccum.parts.append(")");
3711              return printAccum.join();
3712            })
3713       .def_prop_ro(
3714           "name",
3715           [](PyNamedAttribute &self) {
3716             return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3717                            mlirIdentifierStr(self.namedAttr.name).length);
3718           },
3719           "The name of the NamedAttribute binding")
3720       .def_prop_ro(
3721           "attr",
3722           [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3723           nb::keep_alive<0, 1>(),
3724           "The underlying generic attribute of the NamedAttribute binding");
3725 
3726   //----------------------------------------------------------------------------
3727   // Mapping of PyType.
3728   //----------------------------------------------------------------------------
3729   nb::class_<PyType>(m, "Type")
3730       // Delegate to the PyType copy constructor, which will also lifetime
3731       // extend the backing context which owns the MlirType.
3732       .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3733            "Casts the passed type to the generic Type")
3734       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3735       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3736       .def_static(
3737           "parse",
3738           [](std::string typeSpec, DefaultingPyMlirContext context) {
3739             PyMlirContext::ErrorCapture errors(context->getRef());
3740             MlirType type =
3741                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3742             if (mlirTypeIsNull(type))
3743               throw MLIRError("Unable to parse type", errors.take());
3744             return type;
3745           },
3746           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3747           kContextParseTypeDocstring)
3748       .def_prop_ro(
3749           "context", [](PyType &self) { return self.getContext().getObject(); },
3750           "Context that owns the Type")
3751       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3752       .def(
3753           "__eq__", [](PyType &self, nb::object &other) { return false; },
3754           nb::arg("other").none())
3755       .def("__hash__",
3756            [](PyType &self) {
3757              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3758            })
3759       .def(
3760           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3761       .def(
3762           "__str__",
3763           [](PyType &self) {
3764             PyPrintAccumulator printAccum;
3765             mlirTypePrint(self, printAccum.getCallback(),
3766                           printAccum.getUserData());
3767             return printAccum.join();
3768           },
3769           "Returns the assembly form of the type.")
3770       .def("__repr__",
3771            [](PyType &self) {
3772              // Generally, assembly formats are not printed for __repr__ because
3773              // this can cause exceptionally long debug output and exceptions.
3774              // However, types are an exception as they typically have compact
3775              // assembly forms and printing them is useful.
3776              PyPrintAccumulator printAccum;
3777              printAccum.parts.append("Type(");
3778              mlirTypePrint(self, printAccum.getCallback(),
3779                            printAccum.getUserData());
3780              printAccum.parts.append(")");
3781              return printAccum.join();
3782            })
3783       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3784            [](PyType &self) {
3785              MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3786              assert(!mlirTypeIDIsNull(mlirTypeID) &&
3787                     "mlirTypeID was expected to be non-null.");
3788              std::optional<nb::callable> typeCaster =
3789                  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3790                                                    mlirTypeGetDialect(self));
3791              if (!typeCaster)
3792                return nb::cast(self);
3793              return typeCaster.value()(self);
3794            })
3795       .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3796         MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3797         if (!mlirTypeIDIsNull(mlirTypeID))
3798           return mlirTypeID;
3799         auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3800         throw nb::value_error(
3801             (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3802       });
3803 
3804   //----------------------------------------------------------------------------
3805   // Mapping of PyTypeID.
3806   //----------------------------------------------------------------------------
3807   nb::class_<PyTypeID>(m, "TypeID")
3808       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3809       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3810       // Note, this tests whether the underlying TypeIDs are the same,
3811       // not whether the wrapper MlirTypeIDs are the same, nor whether
3812       // the Python objects are the same (i.e., PyTypeID is a value type).
3813       .def("__eq__",
3814            [](PyTypeID &self, PyTypeID &other) { return self == other; })
3815       .def("__eq__",
3816            [](PyTypeID &self, const nb::object &other) { return false; })
3817       // Note, this gives the hash value of the underlying TypeID, not the
3818       // hash value of the Python object, nor the hash value of the
3819       // MlirTypeID wrapper.
3820       .def("__hash__", [](PyTypeID &self) {
3821         return static_cast<size_t>(mlirTypeIDHashValue(self));
3822       });
3823 
3824   //----------------------------------------------------------------------------
3825   // Mapping of Value.
3826   //----------------------------------------------------------------------------
3827   nb::class_<PyValue>(m, "Value")
3828       .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
3829       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3830       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3831       .def_prop_ro(
3832           "context",
3833           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3834           "Context in which the value lives.")
3835       .def(
3836           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3837           kDumpDocstring)
3838       .def_prop_ro(
3839           "owner",
3840           [](PyValue &self) -> nb::object {
3841             MlirValue v = self.get();
3842             if (mlirValueIsAOpResult(v)) {
3843               assert(
3844                   mlirOperationEqual(self.getParentOperation()->get(),
3845                                      mlirOpResultGetOwner(self.get())) &&
3846                   "expected the owner of the value in Python to match that in "
3847                   "the IR");
3848               return self.getParentOperation().getObject();
3849             }
3850 
3851             if (mlirValueIsABlockArgument(v)) {
3852               MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3853               return nb::cast(PyBlock(self.getParentOperation(), block));
3854             }
3855 
3856             assert(false && "Value must be a block argument or an op result");
3857             return nb::none();
3858           })
3859       .def_prop_ro("uses",
3860                    [](PyValue &self) {
3861                      return PyOpOperandIterator(
3862                          mlirValueGetFirstUse(self.get()));
3863                    })
3864       .def("__eq__",
3865            [](PyValue &self, PyValue &other) {
3866              return self.get().ptr == other.get().ptr;
3867            })
3868       .def("__eq__", [](PyValue &self, nb::object other) { return false; })
3869       .def("__hash__",
3870            [](PyValue &self) {
3871              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3872            })
3873       .def(
3874           "__str__",
3875           [](PyValue &self) {
3876             PyPrintAccumulator printAccum;
3877             printAccum.parts.append("Value(");
3878             mlirValuePrint(self.get(), printAccum.getCallback(),
3879                            printAccum.getUserData());
3880             printAccum.parts.append(")");
3881             return printAccum.join();
3882           },
3883           kValueDunderStrDocstring)
3884       .def(
3885           "get_name",
3886           [](PyValue &self, bool useLocalScope) {
3887             PyPrintAccumulator printAccum;
3888             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3889             if (useLocalScope)
3890               mlirOpPrintingFlagsUseLocalScope(flags);
3891             MlirAsmState valueState =
3892                 mlirAsmStateCreateForValue(self.get(), flags);
3893             mlirValuePrintAsOperand(self.get(), valueState,
3894                                     printAccum.getCallback(),
3895                                     printAccum.getUserData());
3896             mlirOpPrintingFlagsDestroy(flags);
3897             mlirAsmStateDestroy(valueState);
3898             return printAccum.join();
3899           },
3900           nb::arg("use_local_scope") = false)
3901       .def(
3902           "get_name",
3903           [](PyValue &self, PyAsmState &state) {
3904             PyPrintAccumulator printAccum;
3905             MlirAsmState valueState = state.get();
3906             mlirValuePrintAsOperand(self.get(), valueState,
3907                                     printAccum.getCallback(),
3908                                     printAccum.getUserData());
3909             return printAccum.join();
3910           },
3911           nb::arg("state"), kGetNameAsOperand)
3912       .def_prop_ro("type",
3913                    [](PyValue &self) { return mlirValueGetType(self.get()); })
3914       .def(
3915           "set_type",
3916           [](PyValue &self, const PyType &type) {
3917             return mlirValueSetType(self.get(), type);
3918           },
3919           nb::arg("type"))
3920       .def(
3921           "replace_all_uses_with",
3922           [](PyValue &self, PyValue &with) {
3923             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3924           },
3925           kValueReplaceAllUsesWithDocstring)
3926       .def(
3927           "replace_all_uses_except",
3928           [](MlirValue self, MlirValue with, PyOperation &exception) {
3929             MlirOperation exceptedUser = exception.get();
3930             mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
3931           },
3932           nb::arg("with"), nb::arg("exceptions"),
3933           kValueReplaceAllUsesExceptDocstring)
3934       .def(
3935           "replace_all_uses_except",
3936           [](MlirValue self, MlirValue with, nb::list exceptions) {
3937             // Convert Python list to a SmallVector of MlirOperations
3938             llvm::SmallVector<MlirOperation> exceptionOps;
3939             for (nb::handle exception : exceptions) {
3940               exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
3941             }
3942 
3943             mlirValueReplaceAllUsesExcept(
3944                 self, with, static_cast<intptr_t>(exceptionOps.size()),
3945                 exceptionOps.data());
3946           },
3947           nb::arg("with"), nb::arg("exceptions"),
3948           kValueReplaceAllUsesExceptDocstring)
3949       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3950            [](PyValue &self) { return self.maybeDownCast(); });
  // Register the Value-related helper classes right after the Value binding.
  // These are kept after the Value class above on the assumption that they
  // refer to it (e.g. as a base class) — TODO confirm in IRModule.h before
  // reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);
  PyOpOperand::bind(m);
3954 
3955   nb::class_<PyAsmState>(m, "AsmState")
3956       .def(nb::init<PyValue &, bool>(), nb::arg("value"),
3957            nb::arg("use_local_scope") = false)
3958       .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
3959            nb::arg("use_local_scope") = false);
3960 
3961   //----------------------------------------------------------------------------
3962   // Mapping of SymbolTable.
3963   //----------------------------------------------------------------------------
3964   nb::class_<PySymbolTable>(m, "SymbolTable")
3965       .def(nb::init<PyOperationBase &>())
3966       .def("__getitem__", &PySymbolTable::dunderGetItem)
3967       .def("insert", &PySymbolTable::insert, nb::arg("operation"))
3968       .def("erase", &PySymbolTable::erase, nb::arg("operation"))
3969       .def("__delitem__", &PySymbolTable::dunderDel)
3970       .def("__contains__",
3971            [](PySymbolTable &table, const std::string &name) {
3972              return !mlirOperationIsNull(mlirSymbolTableLookup(
3973                  table, mlirStringRefCreate(name.data(), name.length())));
3974            })
3975       // Static helpers.
3976       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3977                   nb::arg("symbol"), nb::arg("name"))
3978       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3979                   nb::arg("symbol"))
3980       .def_static("get_visibility", &PySymbolTable::getVisibility,
3981                   nb::arg("symbol"))
3982       .def_static("set_visibility", &PySymbolTable::setVisibility,
3983                   nb::arg("symbol"), nb::arg("visibility"))
3984       .def_static("replace_all_symbol_uses",
3985                   &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
3986                   nb::arg("new_symbol"), nb::arg("from_op"))
3987       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3988                   nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
3989                   nb::arg("callback"));
3990 
  // Container bindings: list/iterator-style helper classes that expose IR
  // collections (block arguments, blocks, operations, attributes, operands,
  // results, successors, regions) to Python. Each helper registers itself on
  // module `m`; the relative order is preserved as-is.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandIterator::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyOpSuccessors::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (global debug flag).
  PyGlobalDebugFlag::bind(m);

  // Attribute builder getter (registered via PyAttrBuilderMap).
  PyAttrBuilderMap::bind(m);
4010 
4011   nb::register_exception_translator([](const std::exception_ptr &p,
4012                                        void *payload) {
4013     // We can't define exceptions with custom fields through pybind, so instead
4014     // the exception class is defined in python and imported here.
4015     try {
4016       if (p)
4017         std::rethrow_exception(p);
4018     } catch (const MLIRError &e) {
4019       nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4020                            .attr("MLIRError")(e.message, e.errorDiagnostics);
4021       PyErr_SetObject(PyExc_Exception, obj.ptr());
4022     }
4023   });
4024 }
4025