xref: /llvm-project/mlir/lib/Bindings/Python/IRCore.cpp (revision acde3f722ff3766f6f793884108d342b78623fe4)
1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9b56d1ec6SPeter Hawkins #include <optional>
10b56d1ec6SPeter Hawkins #include <utility>
11436c6c9cSStella Laurenzo 
12436c6c9cSStella Laurenzo #include "Globals.h"
13b56d1ec6SPeter Hawkins #include "IRModule.h"
14b56d1ec6SPeter Hawkins #include "NanobindUtils.h"
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
164acd8457SAlex Zinenko #include "mlir-c/Debug.h"
173ea4c501SRahul Kayaith #include "mlir-c/Diagnostics.h"
183f3d1c90SMike Urbach #include "mlir-c/IR.h"
195c90e1ffSJacques Pienaar #include "mlir-c/Support.h"
205cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
21b56d1ec6SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h"
225cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23e67cbbefSJacques Pienaar #include "llvm/ADT/ArrayRef.h"
24436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h"
25436c6c9cSStella Laurenzo 
26b56d1ec6SPeter Hawkins namespace nb = nanobind;
27b56d1ec6SPeter Hawkins using namespace nb::literals;
28436c6c9cSStella Laurenzo using namespace mlir;
29436c6c9cSStella Laurenzo using namespace mlir::python;
30436c6c9cSStella Laurenzo 
31436c6c9cSStella Laurenzo using llvm::SmallVector;
32436c6c9cSStella Laurenzo using llvm::StringRef;
33436c6c9cSStella Laurenzo using llvm::Twine;
34436c6c9cSStella Laurenzo 
35436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
36436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
37436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
38436c6c9cSStella Laurenzo 
// Docstrings for type parsing and for the Location factory methods.

static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static constexpr char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static constexpr char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static constexpr char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static constexpr char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static constexpr char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";
6104d76d36SJacques Pienaar 
// Docstrings for module parsing and generic operation creation.

static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
88436c6c9cSStella Laurenzo 
// Docstring for Operation.print. The keyword names listed under "Args" must
// match the keyword argument names registered by the binding; the
// thread-friendly printing option is `use_local_scope` (all lower case).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";
115436c6c9cSStella Laurenzo 
// Docstrings for Operation printing with an AsmState, asm/bytecode emission,
// __str__, dump(), block appending and Value.__str__.

static constexpr char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static constexpr char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
173436c6c9cSStella Laurenzo 
// Docstrings for Value.get_name, Value.replace_all_uses_with and
// Value.replace_all_uses_except.

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";

// Note: this is a raw string literal, so no extra quote character belongs
// after `R"(`. Anything placed there would appear verbatim in the Python
// docstring.
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
18821df3251SPerry Gibson 
189436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
190436c6c9cSStella Laurenzo // Utilities.
191436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
192436c6c9cSStella Laurenzo 
1934acd8457SAlex Zinenko /// Helper for creating an @classmethod.
194436c6c9cSStella Laurenzo template <class Func, typename... Args>
195b56d1ec6SPeter Hawkins nb::object classmethod(Func f, Args... args) {
196b56d1ec6SPeter Hawkins   nb::object cf = nb::cpp_function(f, args...);
197b56d1ec6SPeter Hawkins   return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
198436c6c9cSStella Laurenzo }
199436c6c9cSStella Laurenzo 
200b56d1ec6SPeter Hawkins static nb::object
201436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace,
202b56d1ec6SPeter Hawkins                            nb::object dialectDescriptor) {
203436c6c9cSStella Laurenzo   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
204436c6c9cSStella Laurenzo   if (!dialectClass) {
205436c6c9cSStella Laurenzo     // Use the base class.
206b56d1ec6SPeter Hawkins     return nb::cast(PyDialect(std::move(dialectDescriptor)));
207436c6c9cSStella Laurenzo   }
208436c6c9cSStella Laurenzo 
209436c6c9cSStella Laurenzo   // Create the custom implementation.
210436c6c9cSStella Laurenzo   return (*dialectClass)(std::move(dialectDescriptor));
211436c6c9cSStella Laurenzo }
212436c6c9cSStella Laurenzo 
213436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
214436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
215436c6c9cSStella Laurenzo }
216436c6c9cSStella Laurenzo 
217f4125e02SPeter Hawkins static MlirStringRef toMlirStringRef(std::string_view s) {
218f4125e02SPeter Hawkins   return mlirStringRefCreate(s.data(), s.size());
219f4125e02SPeter Hawkins }
220f4125e02SPeter Hawkins 
221b56d1ec6SPeter Hawkins static MlirStringRef toMlirStringRef(const nb::bytes &s) {
222b56d1ec6SPeter Hawkins   return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
223b56d1ec6SPeter Hawkins }
224b56d1ec6SPeter Hawkins 
225514dddbeSRahul Kayaith /// Create a block, using the current location context if no locations are
226514dddbeSRahul Kayaith /// specified.
227b56d1ec6SPeter Hawkins static MlirBlock createBlock(const nb::sequence &pyArgTypes,
228b56d1ec6SPeter Hawkins                              const std::optional<nb::sequence> &pyArgLocs) {
229514dddbeSRahul Kayaith   SmallVector<MlirType> argTypes;
230b56d1ec6SPeter Hawkins   argTypes.reserve(nb::len(pyArgTypes));
231514dddbeSRahul Kayaith   for (const auto &pyType : pyArgTypes)
232b56d1ec6SPeter Hawkins     argTypes.push_back(nb::cast<PyType &>(pyType));
233514dddbeSRahul Kayaith 
234514dddbeSRahul Kayaith   SmallVector<MlirLocation> argLocs;
235514dddbeSRahul Kayaith   if (pyArgLocs) {
236b56d1ec6SPeter Hawkins     argLocs.reserve(nb::len(*pyArgLocs));
237514dddbeSRahul Kayaith     for (const auto &pyLoc : *pyArgLocs)
238b56d1ec6SPeter Hawkins       argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
239514dddbeSRahul Kayaith   } else if (!argTypes.empty()) {
240514dddbeSRahul Kayaith     argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
241514dddbeSRahul Kayaith   }
242514dddbeSRahul Kayaith 
243514dddbeSRahul Kayaith   if (argTypes.size() != argLocs.size())
244b56d1ec6SPeter Hawkins     throw nb::value_error(("Expected " + Twine(argTypes.size()) +
245514dddbeSRahul Kayaith                            " locations, got: " + Twine(argLocs.size()))
246b56d1ec6SPeter Hawkins                               .str()
247b56d1ec6SPeter Hawkins                               .c_str());
248514dddbeSRahul Kayaith   return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
249514dddbeSRahul Kayaith }
250514dddbeSRahul Kayaith 
2514acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag.
2524acd8457SAlex Zinenko struct PyGlobalDebugFlag {
253f136c800Svfdev   static void set(nb::object &o, bool enable) {
254f136c800Svfdev     nb::ft_lock_guard lock(mutex);
255f136c800Svfdev     mlirEnableGlobalDebug(enable);
256f136c800Svfdev   }
2574acd8457SAlex Zinenko 
258f136c800Svfdev   static bool get(const nb::object &) {
259f136c800Svfdev     nb::ft_lock_guard lock(mutex);
260f136c800Svfdev     return mlirIsGlobalDebugEnabled();
261f136c800Svfdev   }
2624acd8457SAlex Zinenko 
263b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
2644acd8457SAlex Zinenko     // Debug flags.
265b56d1ec6SPeter Hawkins     nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
266b56d1ec6SPeter Hawkins         .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
2678f21909cSOleksandr "Alex" Zinenko                             &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
2688f21909cSOleksandr "Alex" Zinenko         .def_static(
2698f21909cSOleksandr "Alex" Zinenko             "set_types",
2708f21909cSOleksandr "Alex" Zinenko             [](const std::string &type) {
271f136c800Svfdev               nb::ft_lock_guard lock(mutex);
2728f21909cSOleksandr "Alex" Zinenko               mlirSetGlobalDebugType(type.c_str());
2738f21909cSOleksandr "Alex" Zinenko             },
2748f21909cSOleksandr "Alex" Zinenko             "types"_a, "Sets specific debug types to be produced by LLVM")
2758f21909cSOleksandr "Alex" Zinenko         .def_static("set_types", [](const std::vector<std::string> &types) {
2768f21909cSOleksandr "Alex" Zinenko           std::vector<const char *> pointers;
2778f21909cSOleksandr "Alex" Zinenko           pointers.reserve(types.size());
2788f21909cSOleksandr "Alex" Zinenko           for (const std::string &str : types)
2798f21909cSOleksandr "Alex" Zinenko             pointers.push_back(str.c_str());
280f136c800Svfdev           nb::ft_lock_guard lock(mutex);
2818f21909cSOleksandr "Alex" Zinenko           mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
2828f21909cSOleksandr "Alex" Zinenko         });
2834acd8457SAlex Zinenko   }
284f136c800Svfdev 
285f136c800Svfdev private:
286f136c800Svfdev   static nb::ft_mutex mutex;
2874acd8457SAlex Zinenko };
2884acd8457SAlex Zinenko 
289f136c800Svfdev nb::ft_mutex PyGlobalDebugFlag::mutex;
290f136c800Svfdev 
291b57acb9aSJacques Pienaar struct PyAttrBuilderMap {
292b57acb9aSJacques Pienaar   static bool dunderContains(const std::string &attributeKind) {
293b57acb9aSJacques Pienaar     return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
294b57acb9aSJacques Pienaar   }
295a0f5bbcfSvfdev   static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
296b57acb9aSJacques Pienaar     auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
297b57acb9aSJacques Pienaar     if (!builder)
298b56d1ec6SPeter Hawkins       throw nb::key_error(attributeKind.c_str());
299b57acb9aSJacques Pienaar     return *builder;
300b57acb9aSJacques Pienaar   }
301a0f5bbcfSvfdev   static void dunderSetItemNamed(const std::string &attributeKind,
302b56d1ec6SPeter Hawkins                                 nb::callable func, bool replace) {
30392233062Smax     PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
30492233062Smax                                               replace);
305b57acb9aSJacques Pienaar   }
306b57acb9aSJacques Pienaar 
307b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
308b56d1ec6SPeter Hawkins     nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
309b57acb9aSJacques Pienaar         .def_static("contains", &PyAttrBuilderMap::dunderContains)
310a0f5bbcfSvfdev         .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
311a0f5bbcfSvfdev         .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
31292233062Smax                     "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
31392233062Smax                     "Register an attribute builder for building MLIR "
31492233062Smax                     "attributes from python values.");
315b57acb9aSJacques Pienaar   }
316b57acb9aSJacques Pienaar };
317b57acb9aSJacques Pienaar 
318436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
319c83318e3SAdam Paszke // PyBlock
320c83318e3SAdam Paszke //------------------------------------------------------------------------------
321c83318e3SAdam Paszke 
322b56d1ec6SPeter Hawkins nb::object PyBlock::getCapsule() {
323b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
324c83318e3SAdam Paszke }
325c83318e3SAdam Paszke 
326c83318e3SAdam Paszke //------------------------------------------------------------------------------
327436c6c9cSStella Laurenzo // Collections.
328436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
329436c6c9cSStella Laurenzo 
330436c6c9cSStella Laurenzo namespace {
331436c6c9cSStella Laurenzo 
332436c6c9cSStella Laurenzo class PyRegionIterator {
333436c6c9cSStella Laurenzo public:
334436c6c9cSStella Laurenzo   PyRegionIterator(PyOperationRef operation)
335436c6c9cSStella Laurenzo       : operation(std::move(operation)) {}
336436c6c9cSStella Laurenzo 
337436c6c9cSStella Laurenzo   PyRegionIterator &dunderIter() { return *this; }
338436c6c9cSStella Laurenzo 
339436c6c9cSStella Laurenzo   PyRegion dunderNext() {
340436c6c9cSStella Laurenzo     operation->checkValid();
341436c6c9cSStella Laurenzo     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
342b56d1ec6SPeter Hawkins       throw nb::stop_iteration();
343436c6c9cSStella Laurenzo     }
344436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
345436c6c9cSStella Laurenzo     return PyRegion(operation, region);
346436c6c9cSStella Laurenzo   }
347436c6c9cSStella Laurenzo 
348b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
349b56d1ec6SPeter Hawkins     nb::class_<PyRegionIterator>(m, "RegionIterator")
350436c6c9cSStella Laurenzo         .def("__iter__", &PyRegionIterator::dunderIter)
351436c6c9cSStella Laurenzo         .def("__next__", &PyRegionIterator::dunderNext);
352436c6c9cSStella Laurenzo   }
353436c6c9cSStella Laurenzo 
354436c6c9cSStella Laurenzo private:
355436c6c9cSStella Laurenzo   PyOperationRef operation;
356436c6c9cSStella Laurenzo   int nextIndex = 0;
357436c6c9cSStella Laurenzo };
358436c6c9cSStella Laurenzo 
359436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented
360436c6c9cSStella Laurenzo /// with a sequence-like container.
361436c6c9cSStella Laurenzo class PyRegionList {
362436c6c9cSStella Laurenzo public:
363436c6c9cSStella Laurenzo   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
364436c6c9cSStella Laurenzo 
365d0d26ee7SRahul Kayaith   PyRegionIterator dunderIter() {
366d0d26ee7SRahul Kayaith     operation->checkValid();
367d0d26ee7SRahul Kayaith     return PyRegionIterator(operation);
368d0d26ee7SRahul Kayaith   }
369d0d26ee7SRahul Kayaith 
370436c6c9cSStella Laurenzo   intptr_t dunderLen() {
371436c6c9cSStella Laurenzo     operation->checkValid();
372436c6c9cSStella Laurenzo     return mlirOperationGetNumRegions(operation->get());
373436c6c9cSStella Laurenzo   }
374436c6c9cSStella Laurenzo 
375436c6c9cSStella Laurenzo   PyRegion dunderGetItem(intptr_t index) {
376436c6c9cSStella Laurenzo     // dunderLen checks validity.
377436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
378b56d1ec6SPeter Hawkins       throw nb::index_error("attempt to access out of bounds region");
379436c6c9cSStella Laurenzo     }
380436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
381436c6c9cSStella Laurenzo     return PyRegion(operation, region);
382436c6c9cSStella Laurenzo   }
383436c6c9cSStella Laurenzo 
384b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
385b56d1ec6SPeter Hawkins     nb::class_<PyRegionList>(m, "RegionSequence")
386436c6c9cSStella Laurenzo         .def("__len__", &PyRegionList::dunderLen)
387d0d26ee7SRahul Kayaith         .def("__iter__", &PyRegionList::dunderIter)
388436c6c9cSStella Laurenzo         .def("__getitem__", &PyRegionList::dunderGetItem);
389436c6c9cSStella Laurenzo   }
390436c6c9cSStella Laurenzo 
391436c6c9cSStella Laurenzo private:
392436c6c9cSStella Laurenzo   PyOperationRef operation;
393436c6c9cSStella Laurenzo };
394436c6c9cSStella Laurenzo 
395436c6c9cSStella Laurenzo class PyBlockIterator {
396436c6c9cSStella Laurenzo public:
397436c6c9cSStella Laurenzo   PyBlockIterator(PyOperationRef operation, MlirBlock next)
398436c6c9cSStella Laurenzo       : operation(std::move(operation)), next(next) {}
399436c6c9cSStella Laurenzo 
400436c6c9cSStella Laurenzo   PyBlockIterator &dunderIter() { return *this; }
401436c6c9cSStella Laurenzo 
402436c6c9cSStella Laurenzo   PyBlock dunderNext() {
403436c6c9cSStella Laurenzo     operation->checkValid();
404436c6c9cSStella Laurenzo     if (mlirBlockIsNull(next)) {
405b56d1ec6SPeter Hawkins       throw nb::stop_iteration();
406436c6c9cSStella Laurenzo     }
407436c6c9cSStella Laurenzo 
408436c6c9cSStella Laurenzo     PyBlock returnBlock(operation, next);
409436c6c9cSStella Laurenzo     next = mlirBlockGetNextInRegion(next);
410436c6c9cSStella Laurenzo     return returnBlock;
411436c6c9cSStella Laurenzo   }
412436c6c9cSStella Laurenzo 
413b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
414b56d1ec6SPeter Hawkins     nb::class_<PyBlockIterator>(m, "BlockIterator")
415436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockIterator::dunderIter)
416436c6c9cSStella Laurenzo         .def("__next__", &PyBlockIterator::dunderNext);
417436c6c9cSStella Laurenzo   }
418436c6c9cSStella Laurenzo 
419436c6c9cSStella Laurenzo private:
420436c6c9cSStella Laurenzo   PyOperationRef operation;
421436c6c9cSStella Laurenzo   MlirBlock next;
422436c6c9cSStella Laurenzo };
423436c6c9cSStella Laurenzo 
424436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
425436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize
426436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region.
427436c6c9cSStella Laurenzo class PyBlockList {
428436c6c9cSStella Laurenzo public:
429436c6c9cSStella Laurenzo   PyBlockList(PyOperationRef operation, MlirRegion region)
430436c6c9cSStella Laurenzo       : operation(std::move(operation)), region(region) {}
431436c6c9cSStella Laurenzo 
432436c6c9cSStella Laurenzo   PyBlockIterator dunderIter() {
433436c6c9cSStella Laurenzo     operation->checkValid();
434436c6c9cSStella Laurenzo     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
435436c6c9cSStella Laurenzo   }
436436c6c9cSStella Laurenzo 
437436c6c9cSStella Laurenzo   intptr_t dunderLen() {
438436c6c9cSStella Laurenzo     operation->checkValid();
439436c6c9cSStella Laurenzo     intptr_t count = 0;
440436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
441436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
442436c6c9cSStella Laurenzo       count += 1;
443436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
444436c6c9cSStella Laurenzo     }
445436c6c9cSStella Laurenzo     return count;
446436c6c9cSStella Laurenzo   }
447436c6c9cSStella Laurenzo 
448436c6c9cSStella Laurenzo   PyBlock dunderGetItem(intptr_t index) {
449436c6c9cSStella Laurenzo     operation->checkValid();
450436c6c9cSStella Laurenzo     if (index < 0) {
451b56d1ec6SPeter Hawkins       throw nb::index_error("attempt to access out of bounds block");
452436c6c9cSStella Laurenzo     }
453436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
454436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
455436c6c9cSStella Laurenzo       if (index == 0) {
456436c6c9cSStella Laurenzo         return PyBlock(operation, block);
457436c6c9cSStella Laurenzo       }
458436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
459436c6c9cSStella Laurenzo       index -= 1;
460436c6c9cSStella Laurenzo     }
461b56d1ec6SPeter Hawkins     throw nb::index_error("attempt to access out of bounds block");
462436c6c9cSStella Laurenzo   }
463436c6c9cSStella Laurenzo 
464b56d1ec6SPeter Hawkins   PyBlock appendBlock(const nb::args &pyArgTypes,
465b56d1ec6SPeter Hawkins                       const std::optional<nb::sequence> &pyArgLocs) {
466436c6c9cSStella Laurenzo     operation->checkValid();
467b56d1ec6SPeter Hawkins     MlirBlock block =
468b56d1ec6SPeter Hawkins         createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
469436c6c9cSStella Laurenzo     mlirRegionAppendOwnedBlock(region, block);
470436c6c9cSStella Laurenzo     return PyBlock(operation, block);
471436c6c9cSStella Laurenzo   }
472436c6c9cSStella Laurenzo 
473b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
474b56d1ec6SPeter Hawkins     nb::class_<PyBlockList>(m, "BlockList")
475436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockList::dunderGetItem)
476436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockList::dunderIter)
477436c6c9cSStella Laurenzo         .def("__len__", &PyBlockList::dunderLen)
478514dddbeSRahul Kayaith         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
479b56d1ec6SPeter Hawkins              nb::arg("args"), nb::kw_only(),
480b56d1ec6SPeter Hawkins              nb::arg("arg_locs") = std::nullopt);
481436c6c9cSStella Laurenzo   }
482436c6c9cSStella Laurenzo 
483436c6c9cSStella Laurenzo private:
484436c6c9cSStella Laurenzo   PyOperationRef operation;
485436c6c9cSStella Laurenzo   MlirRegion region;
486436c6c9cSStella Laurenzo };
487436c6c9cSStella Laurenzo 
488436c6c9cSStella Laurenzo class PyOperationIterator {
489436c6c9cSStella Laurenzo public:
490436c6c9cSStella Laurenzo   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
491436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), next(next) {}
492436c6c9cSStella Laurenzo 
493436c6c9cSStella Laurenzo   PyOperationIterator &dunderIter() { return *this; }
494436c6c9cSStella Laurenzo 
495b56d1ec6SPeter Hawkins   nb::object dunderNext() {
496436c6c9cSStella Laurenzo     parentOperation->checkValid();
497436c6c9cSStella Laurenzo     if (mlirOperationIsNull(next)) {
498b56d1ec6SPeter Hawkins       throw nb::stop_iteration();
499436c6c9cSStella Laurenzo     }
500436c6c9cSStella Laurenzo 
501436c6c9cSStella Laurenzo     PyOperationRef returnOperation =
502436c6c9cSStella Laurenzo         PyOperation::forOperation(parentOperation->getContext(), next);
503436c6c9cSStella Laurenzo     next = mlirOperationGetNextInBlock(next);
504436c6c9cSStella Laurenzo     return returnOperation->createOpView();
505436c6c9cSStella Laurenzo   }
506436c6c9cSStella Laurenzo 
507b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
508b56d1ec6SPeter Hawkins     nb::class_<PyOperationIterator>(m, "OperationIterator")
509436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationIterator::dunderIter)
510436c6c9cSStella Laurenzo         .def("__next__", &PyOperationIterator::dunderNext);
511436c6c9cSStella Laurenzo   }
512436c6c9cSStella Laurenzo 
513436c6c9cSStella Laurenzo private:
514436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
515436c6c9cSStella Laurenzo   MlirOperation next;
516436c6c9cSStella Laurenzo };
517436c6c9cSStella Laurenzo 
518436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In
519436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but
520436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned
521436c6c9cSStella Laurenzo /// by a block.
522436c6c9cSStella Laurenzo class PyOperationList {
523436c6c9cSStella Laurenzo public:
524436c6c9cSStella Laurenzo   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
525436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {}
526436c6c9cSStella Laurenzo 
527436c6c9cSStella Laurenzo   PyOperationIterator dunderIter() {
528436c6c9cSStella Laurenzo     parentOperation->checkValid();
529436c6c9cSStella Laurenzo     return PyOperationIterator(parentOperation,
530436c6c9cSStella Laurenzo                                mlirBlockGetFirstOperation(block));
531436c6c9cSStella Laurenzo   }
532436c6c9cSStella Laurenzo 
533436c6c9cSStella Laurenzo   intptr_t dunderLen() {
534436c6c9cSStella Laurenzo     parentOperation->checkValid();
535436c6c9cSStella Laurenzo     intptr_t count = 0;
536436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
537436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
538436c6c9cSStella Laurenzo       count += 1;
539436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
540436c6c9cSStella Laurenzo     }
541436c6c9cSStella Laurenzo     return count;
542436c6c9cSStella Laurenzo   }
543436c6c9cSStella Laurenzo 
544b56d1ec6SPeter Hawkins   nb::object dunderGetItem(intptr_t index) {
545436c6c9cSStella Laurenzo     parentOperation->checkValid();
546436c6c9cSStella Laurenzo     if (index < 0) {
547b56d1ec6SPeter Hawkins       throw nb::index_error("attempt to access out of bounds operation");
548436c6c9cSStella Laurenzo     }
549436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
550436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
551436c6c9cSStella Laurenzo       if (index == 0) {
552436c6c9cSStella Laurenzo         return PyOperation::forOperation(parentOperation->getContext(), childOp)
553436c6c9cSStella Laurenzo             ->createOpView();
554436c6c9cSStella Laurenzo       }
555436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
556436c6c9cSStella Laurenzo       index -= 1;
557436c6c9cSStella Laurenzo     }
558b56d1ec6SPeter Hawkins     throw nb::index_error("attempt to access out of bounds operation");
559436c6c9cSStella Laurenzo   }
560436c6c9cSStella Laurenzo 
561b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
562b56d1ec6SPeter Hawkins     nb::class_<PyOperationList>(m, "OperationList")
563436c6c9cSStella Laurenzo         .def("__getitem__", &PyOperationList::dunderGetItem)
564436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationList::dunderIter)
565436c6c9cSStella Laurenzo         .def("__len__", &PyOperationList::dunderLen);
566436c6c9cSStella Laurenzo   }
567436c6c9cSStella Laurenzo 
568436c6c9cSStella Laurenzo private:
569436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
570436c6c9cSStella Laurenzo   MlirBlock block;
571436c6c9cSStella Laurenzo };
572436c6c9cSStella Laurenzo 
573afb2ed80SMike Urbach class PyOpOperand {
574afb2ed80SMike Urbach public:
575afb2ed80SMike Urbach   PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
576afb2ed80SMike Urbach 
577b56d1ec6SPeter Hawkins   nb::object getOwner() {
578afb2ed80SMike Urbach     MlirOperation owner = mlirOpOperandGetOwner(opOperand);
579afb2ed80SMike Urbach     PyMlirContextRef context =
580afb2ed80SMike Urbach         PyMlirContext::forContext(mlirOperationGetContext(owner));
581afb2ed80SMike Urbach     return PyOperation::forOperation(context, owner)->createOpView();
582afb2ed80SMike Urbach   }
583afb2ed80SMike Urbach 
584afb2ed80SMike Urbach   size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
585afb2ed80SMike Urbach 
586b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
587b56d1ec6SPeter Hawkins     nb::class_<PyOpOperand>(m, "OpOperand")
588b56d1ec6SPeter Hawkins         .def_prop_ro("owner", &PyOpOperand::getOwner)
589b56d1ec6SPeter Hawkins         .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
590afb2ed80SMike Urbach   }
591afb2ed80SMike Urbach 
592afb2ed80SMike Urbach private:
593afb2ed80SMike Urbach   MlirOpOperand opOperand;
594afb2ed80SMike Urbach };
595afb2ed80SMike Urbach 
596afb2ed80SMike Urbach class PyOpOperandIterator {
597afb2ed80SMike Urbach public:
598afb2ed80SMike Urbach   PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
599afb2ed80SMike Urbach 
600afb2ed80SMike Urbach   PyOpOperandIterator &dunderIter() { return *this; }
601afb2ed80SMike Urbach 
602afb2ed80SMike Urbach   PyOpOperand dunderNext() {
603afb2ed80SMike Urbach     if (mlirOpOperandIsNull(opOperand))
604b56d1ec6SPeter Hawkins       throw nb::stop_iteration();
605afb2ed80SMike Urbach 
606afb2ed80SMike Urbach     PyOpOperand returnOpOperand(opOperand);
607afb2ed80SMike Urbach     opOperand = mlirOpOperandGetNextUse(opOperand);
608afb2ed80SMike Urbach     return returnOpOperand;
609afb2ed80SMike Urbach   }
610afb2ed80SMike Urbach 
611b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
612b56d1ec6SPeter Hawkins     nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
613afb2ed80SMike Urbach         .def("__iter__", &PyOpOperandIterator::dunderIter)
614afb2ed80SMike Urbach         .def("__next__", &PyOpOperandIterator::dunderNext);
615afb2ed80SMike Urbach   }
616afb2ed80SMike Urbach 
617afb2ed80SMike Urbach private:
618afb2ed80SMike Urbach   MlirOpOperand opOperand;
619afb2ed80SMike Urbach };
620afb2ed80SMike Urbach 
621436c6c9cSStella Laurenzo } // namespace
622436c6c9cSStella Laurenzo 
623436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
624436c6c9cSStella Laurenzo // PyMlirContext
625436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
626436c6c9cSStella Laurenzo 
627436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
628b56d1ec6SPeter Hawkins   nb::gil_scoped_acquire acquire;
629f136c800Svfdev   nb::ft_lock_guard lock(live_contexts_mutex);
630436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
631436c6c9cSStella Laurenzo   liveContexts[context.ptr] = this;
632436c6c9cSStella Laurenzo }
633436c6c9cSStella Laurenzo 
634436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() {
635436c6c9cSStella Laurenzo   // Note that the only public way to construct an instance is via the
636436c6c9cSStella Laurenzo   // forContext method, which always puts the associated handle into
637436c6c9cSStella Laurenzo   // liveContexts.
638b56d1ec6SPeter Hawkins   nb::gil_scoped_acquire acquire;
639f136c800Svfdev   {
640f136c800Svfdev     nb::ft_lock_guard lock(live_contexts_mutex);
641436c6c9cSStella Laurenzo     getLiveContexts().erase(context.ptr);
642f136c800Svfdev   }
643436c6c9cSStella Laurenzo   mlirContextDestroy(context);
644436c6c9cSStella Laurenzo }
645436c6c9cSStella Laurenzo 
646b56d1ec6SPeter Hawkins nb::object PyMlirContext::getCapsule() {
647b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
648436c6c9cSStella Laurenzo }
649436c6c9cSStella Laurenzo 
650b56d1ec6SPeter Hawkins nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
651436c6c9cSStella Laurenzo   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
652436c6c9cSStella Laurenzo   if (mlirContextIsNull(rawContext))
653b56d1ec6SPeter Hawkins     throw nb::python_error();
65478bd1246SAlex Zinenko   return forContext(rawContext).releaseObject();
655436c6c9cSStella Laurenzo }
656436c6c9cSStella Laurenzo 
657436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
658b56d1ec6SPeter Hawkins   nb::gil_scoped_acquire acquire;
659f136c800Svfdev   nb::ft_lock_guard lock(live_contexts_mutex);
660436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
661436c6c9cSStella Laurenzo   auto it = liveContexts.find(context.ptr);
662436c6c9cSStella Laurenzo   if (it == liveContexts.end()) {
66378bd1246SAlex Zinenko     // Create.
66478bd1246SAlex Zinenko     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
665b56d1ec6SPeter Hawkins     nb::object pyRef = nb::cast(unownedContextWrapper);
666b56d1ec6SPeter Hawkins     assert(pyRef && "cast to nb::object failed");
66778bd1246SAlex Zinenko     liveContexts[context.ptr] = unownedContextWrapper;
66878bd1246SAlex Zinenko     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
669436c6c9cSStella Laurenzo   }
670436c6c9cSStella Laurenzo   // Use existing.
671b56d1ec6SPeter Hawkins   nb::object pyRef = nb::cast(it->second);
672436c6c9cSStella Laurenzo   return PyMlirContextRef(it->second, std::move(pyRef));
673436c6c9cSStella Laurenzo }
674436c6c9cSStella Laurenzo 
// Guards the map returned by getLiveContexts(); it is taken by the
// PyMlirContext constructor/destructor, forContext and getLiveCount.
nb::ft_mutex PyMlirContext::live_contexts_mutex;
676f136c800Svfdev 
677436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
678436c6c9cSStella Laurenzo   static LiveContextMap liveContexts;
679436c6c9cSStella Laurenzo   return liveContexts;
680436c6c9cSStella Laurenzo }
681436c6c9cSStella Laurenzo 
682f136c800Svfdev size_t PyMlirContext::getLiveCount() {
683f136c800Svfdev   nb::ft_lock_guard lock(live_contexts_mutex);
684f136c800Svfdev   return getLiveContexts().size();
685f136c800Svfdev }
686436c6c9cSStella Laurenzo 
687e2c49a45SPeter Hawkins size_t PyMlirContext::getLiveOperationCount() {
688e2c49a45SPeter Hawkins   nb::ft_lock_guard lock(liveOperationsMutex);
689e2c49a45SPeter Hawkins   return liveOperations.size();
690e2c49a45SPeter Hawkins }
691436c6c9cSStella Laurenzo 
692d1fdb416SJohn Demme std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
693d1fdb416SJohn Demme   std::vector<PyOperation *> liveObjects;
694e2c49a45SPeter Hawkins   nb::ft_lock_guard lock(liveOperationsMutex);
695d1fdb416SJohn Demme   for (auto &entry : liveOperations)
696d1fdb416SJohn Demme     liveObjects.push_back(entry.second.second);
697d1fdb416SJohn Demme   return liveObjects;
698d1fdb416SJohn Demme }
699d1fdb416SJohn Demme 
7006b0bed7eSJohn Demme size_t PyMlirContext::clearLiveOperations() {
701e2c49a45SPeter Hawkins 
702e2c49a45SPeter Hawkins   LiveOperationMap operations;
703e2c49a45SPeter Hawkins   {
704e2c49a45SPeter Hawkins     nb::ft_lock_guard lock(liveOperationsMutex);
705e2c49a45SPeter Hawkins     std::swap(operations, liveOperations);
706e2c49a45SPeter Hawkins   }
707e2c49a45SPeter Hawkins   for (auto &op : operations)
7086b0bed7eSJohn Demme     op.second.second->setInvalid();
709e2c49a45SPeter Hawkins   size_t numInvalidated = operations.size();
7106b0bed7eSJohn Demme   return numInvalidated;
7116b0bed7eSJohn Demme }
7126b0bed7eSJohn Demme 
713fa19ef7aSIngo Müller void PyMlirContext::clearOperation(MlirOperation op) {
714e2c49a45SPeter Hawkins   PyOperation *py_op;
715e2c49a45SPeter Hawkins   {
716e2c49a45SPeter Hawkins     nb::ft_lock_guard lock(liveOperationsMutex);
717fa19ef7aSIngo Müller     auto it = liveOperations.find(op.ptr);
718e2c49a45SPeter Hawkins     if (it == liveOperations.end()) {
719e2c49a45SPeter Hawkins       return;
720e2c49a45SPeter Hawkins     }
721e2c49a45SPeter Hawkins     py_op = it->second.second;
722fa19ef7aSIngo Müller     liveOperations.erase(it);
723fa19ef7aSIngo Müller   }
724e2c49a45SPeter Hawkins   py_op->setInvalid();
725fa19ef7aSIngo Müller }
726fa19ef7aSIngo Müller 
727fa19ef7aSIngo Müller void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
728fa19ef7aSIngo Müller   typedef struct {
729fa19ef7aSIngo Müller     PyOperation &rootOp;
730fa19ef7aSIngo Müller     bool rootSeen;
731fa19ef7aSIngo Müller   } callBackData;
732fa19ef7aSIngo Müller   callBackData data{op.getOperation(), false};
733fa19ef7aSIngo Müller   // Mark all ops below the op that the passmanager will be rooted
734fa19ef7aSIngo Müller   // at (but not op itself - note the preorder) as invalid.
735fa19ef7aSIngo Müller   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
736fa19ef7aSIngo Müller                                                       void *userData) {
737fa19ef7aSIngo Müller     callBackData *data = static_cast<callBackData *>(userData);
738fa19ef7aSIngo Müller     if (LLVM_LIKELY(data->rootSeen))
739fa19ef7aSIngo Müller       data->rootOp.getOperation().getContext()->clearOperation(op);
740fa19ef7aSIngo Müller     else
741fa19ef7aSIngo Müller       data->rootSeen = true;
74247148832SHideto Ueno     return MlirWalkResult::MlirWalkResultAdvance;
743fa19ef7aSIngo Müller   };
744fa19ef7aSIngo Müller   mlirOperationWalk(op.getOperation(), invalidatingCallback,
745fa19ef7aSIngo Müller                     static_cast<void *>(&data), MlirWalkPreOrder);
746bdc3e6cbSMaksim Levental }
74791f11611SOleksandr "Alex" Zinenko void PyMlirContext::clearOperationsInside(MlirOperation op) {
74891f11611SOleksandr "Alex" Zinenko   PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
74991f11611SOleksandr "Alex" Zinenko   clearOperationsInside(opRef->getOperation());
75091f11611SOleksandr "Alex" Zinenko }
751bdc3e6cbSMaksim Levental 
75267897d77SOleksandr "Alex" Zinenko void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
75367897d77SOleksandr "Alex" Zinenko   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
75467897d77SOleksandr "Alex" Zinenko                                                       void *userData) {
75567897d77SOleksandr "Alex" Zinenko     PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
75667897d77SOleksandr "Alex" Zinenko     contextRef->clearOperation(op);
75767897d77SOleksandr "Alex" Zinenko     return MlirWalkResult::MlirWalkResultAdvance;
75867897d77SOleksandr "Alex" Zinenko   };
75967897d77SOleksandr "Alex" Zinenko   mlirOperationWalk(op.getOperation(), invalidatingCallback,
76067897d77SOleksandr "Alex" Zinenko                     &op.getOperation().getContext(), MlirWalkPreOrder);
76167897d77SOleksandr "Alex" Zinenko }
76267897d77SOleksandr "Alex" Zinenko 
763436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
764436c6c9cSStella Laurenzo 
765b56d1ec6SPeter Hawkins nb::object PyMlirContext::contextEnter(nb::object context) {
766b56d1ec6SPeter Hawkins   return PyThreadContextEntry::pushContext(context);
767436c6c9cSStella Laurenzo }
768436c6c9cSStella Laurenzo 
769b56d1ec6SPeter Hawkins void PyMlirContext::contextExit(const nb::object &excType,
770b56d1ec6SPeter Hawkins                                 const nb::object &excVal,
771b56d1ec6SPeter Hawkins                                 const nb::object &excTb) {
772436c6c9cSStella Laurenzo   PyThreadContextEntry::popContext(*this);
773436c6c9cSStella Laurenzo }
774436c6c9cSStella Laurenzo 
775b56d1ec6SPeter Hawkins nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
7767ee25bc5SStella Laurenzo   // Note that ownership is transferred to the delete callback below by way of
7777ee25bc5SStella Laurenzo   // an explicit inc_ref (borrow).
7787ee25bc5SStella Laurenzo   PyDiagnosticHandler *pyHandler =
7797ee25bc5SStella Laurenzo       new PyDiagnosticHandler(get(), std::move(callback));
780b56d1ec6SPeter Hawkins   nb::object pyHandlerObject =
781b56d1ec6SPeter Hawkins       nb::cast(pyHandler, nb::rv_policy::take_ownership);
7827ee25bc5SStella Laurenzo   pyHandlerObject.inc_ref();
7837ee25bc5SStella Laurenzo 
7847ee25bc5SStella Laurenzo   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
7857ee25bc5SStella Laurenzo   // guaranteed to be known to pybind.
7867ee25bc5SStella Laurenzo   auto handlerCallback =
7877ee25bc5SStella Laurenzo       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
7887ee25bc5SStella Laurenzo     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
789b56d1ec6SPeter Hawkins     nb::object pyDiagnosticObject =
790b56d1ec6SPeter Hawkins         nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
7917ee25bc5SStella Laurenzo 
7927ee25bc5SStella Laurenzo     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
7937ee25bc5SStella Laurenzo     bool result = false;
7947ee25bc5SStella Laurenzo     {
7957ee25bc5SStella Laurenzo       // Since this can be called from arbitrary C++ contexts, always get the
7967ee25bc5SStella Laurenzo       // gil.
797b56d1ec6SPeter Hawkins       nb::gil_scoped_acquire gil;
7987ee25bc5SStella Laurenzo       try {
799b56d1ec6SPeter Hawkins         result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
8007ee25bc5SStella Laurenzo       } catch (std::exception &e) {
8017ee25bc5SStella Laurenzo         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
8027ee25bc5SStella Laurenzo                 e.what());
8037ee25bc5SStella Laurenzo         pyHandler->hadError = true;
8047ee25bc5SStella Laurenzo       }
8057ee25bc5SStella Laurenzo     }
8067ee25bc5SStella Laurenzo 
8077ee25bc5SStella Laurenzo     pyDiagnostic->invalidate();
8087ee25bc5SStella Laurenzo     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
8097ee25bc5SStella Laurenzo   };
8107ee25bc5SStella Laurenzo   auto deleteCallback = +[](void *userData) {
8117ee25bc5SStella Laurenzo     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
8127ee25bc5SStella Laurenzo     assert(pyHandler->registeredID && "handler is not registered");
8137ee25bc5SStella Laurenzo     pyHandler->registeredID.reset();
8147ee25bc5SStella Laurenzo 
8157ee25bc5SStella Laurenzo     // Decrement reference, balancing the inc_ref() above.
816b56d1ec6SPeter Hawkins     nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
8177ee25bc5SStella Laurenzo     pyHandlerObject.dec_ref();
8187ee25bc5SStella Laurenzo   };
8197ee25bc5SStella Laurenzo 
8207ee25bc5SStella Laurenzo   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
8217ee25bc5SStella Laurenzo       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
8227ee25bc5SStella Laurenzo   return pyHandlerObject;
8237ee25bc5SStella Laurenzo }
8247ee25bc5SStella Laurenzo 
8253ea4c501SRahul Kayaith MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
8263ea4c501SRahul Kayaith                                                        void *userData) {
8273ea4c501SRahul Kayaith   auto *self = static_cast<ErrorCapture *>(userData);
8283ea4c501SRahul Kayaith   // Check if the context requested we emit errors instead of capturing them.
8293ea4c501SRahul Kayaith   if (self->ctx->emitErrorDiagnostics)
8303ea4c501SRahul Kayaith     return mlirLogicalResultFailure();
8313ea4c501SRahul Kayaith 
8323ea4c501SRahul Kayaith   if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
8333ea4c501SRahul Kayaith     return mlirLogicalResultFailure();
8343ea4c501SRahul Kayaith 
8353ea4c501SRahul Kayaith   self->errors.emplace_back(PyDiagnostic(diag).getInfo());
8363ea4c501SRahul Kayaith   return mlirLogicalResultSuccess();
8373ea4c501SRahul Kayaith }
8383ea4c501SRahul Kayaith 
839436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() {
840436c6c9cSStella Laurenzo   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
841436c6c9cSStella Laurenzo   if (!context) {
8424811270bSmax     throw std::runtime_error(
843436c6c9cSStella Laurenzo         "An MLIR function requires a Context but none was provided in the call "
844436c6c9cSStella Laurenzo         "or from the surrounding environment. Either pass to the function with "
845436c6c9cSStella Laurenzo         "a 'context=' argument or establish a default using 'with Context():'");
846436c6c9cSStella Laurenzo   }
847436c6c9cSStella Laurenzo   return *context;
848436c6c9cSStella Laurenzo }
849436c6c9cSStella Laurenzo 
850436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
851436c6c9cSStella Laurenzo // PyThreadContextEntry management
852436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
853436c6c9cSStella Laurenzo 
854436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
855436c6c9cSStella Laurenzo   static thread_local std::vector<PyThreadContextEntry> stack;
856436c6c9cSStella Laurenzo   return stack;
857436c6c9cSStella Laurenzo }
858436c6c9cSStella Laurenzo 
859436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
860436c6c9cSStella Laurenzo   auto &stack = getStack();
861436c6c9cSStella Laurenzo   if (stack.empty())
862436c6c9cSStella Laurenzo     return nullptr;
863436c6c9cSStella Laurenzo   return &stack.back();
864436c6c9cSStella Laurenzo }
865436c6c9cSStella Laurenzo 
866b56d1ec6SPeter Hawkins void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
867b56d1ec6SPeter Hawkins                                 nb::object insertionPoint,
868b56d1ec6SPeter Hawkins                                 nb::object location) {
869436c6c9cSStella Laurenzo   auto &stack = getStack();
870436c6c9cSStella Laurenzo   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
871436c6c9cSStella Laurenzo                      std::move(location));
872436c6c9cSStella Laurenzo   // If the new stack has more than one entry and the context of the new top
873436c6c9cSStella Laurenzo   // entry matches the previous, copy the insertionPoint and location from the
874436c6c9cSStella Laurenzo   // previous entry if missing from the new top entry.
875436c6c9cSStella Laurenzo   if (stack.size() > 1) {
876436c6c9cSStella Laurenzo     auto &prev = *(stack.rbegin() + 1);
877436c6c9cSStella Laurenzo     auto &current = stack.back();
878436c6c9cSStella Laurenzo     if (current.context.is(prev.context)) {
879436c6c9cSStella Laurenzo       // Default non-context objects from the previous entry.
880436c6c9cSStella Laurenzo       if (!current.insertionPoint)
881436c6c9cSStella Laurenzo         current.insertionPoint = prev.insertionPoint;
882436c6c9cSStella Laurenzo       if (!current.location)
883436c6c9cSStella Laurenzo         current.location = prev.location;
884436c6c9cSStella Laurenzo     }
885436c6c9cSStella Laurenzo   }
886436c6c9cSStella Laurenzo }
887436c6c9cSStella Laurenzo 
888436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() {
889436c6c9cSStella Laurenzo   if (!context)
890436c6c9cSStella Laurenzo     return nullptr;
891b56d1ec6SPeter Hawkins   return nb::cast<PyMlirContext *>(context);
892436c6c9cSStella Laurenzo }
893436c6c9cSStella Laurenzo 
894436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
895436c6c9cSStella Laurenzo   if (!insertionPoint)
896436c6c9cSStella Laurenzo     return nullptr;
897b56d1ec6SPeter Hawkins   return nb::cast<PyInsertionPoint *>(insertionPoint);
898436c6c9cSStella Laurenzo }
899436c6c9cSStella Laurenzo 
900436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() {
901436c6c9cSStella Laurenzo   if (!location)
902436c6c9cSStella Laurenzo     return nullptr;
903b56d1ec6SPeter Hawkins   return nb::cast<PyLocation *>(location);
904436c6c9cSStella Laurenzo }
905436c6c9cSStella Laurenzo 
906436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() {
907436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
908436c6c9cSStella Laurenzo   return tos ? tos->getContext() : nullptr;
909436c6c9cSStella Laurenzo }
910436c6c9cSStella Laurenzo 
911436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
912436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
913436c6c9cSStella Laurenzo   return tos ? tos->getInsertionPoint() : nullptr;
914436c6c9cSStella Laurenzo }
915436c6c9cSStella Laurenzo 
916436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() {
917436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
918436c6c9cSStella Laurenzo   return tos ? tos->getLocation() : nullptr;
919436c6c9cSStella Laurenzo }
920436c6c9cSStella Laurenzo 
921b56d1ec6SPeter Hawkins nb::object PyThreadContextEntry::pushContext(nb::object context) {
922b56d1ec6SPeter Hawkins   push(FrameKind::Context, /*context=*/context,
923b56d1ec6SPeter Hawkins        /*insertionPoint=*/nb::object(),
924b56d1ec6SPeter Hawkins        /*location=*/nb::object());
925b56d1ec6SPeter Hawkins   return context;
926436c6c9cSStella Laurenzo }
927436c6c9cSStella Laurenzo 
928436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) {
929436c6c9cSStella Laurenzo   auto &stack = getStack();
930436c6c9cSStella Laurenzo   if (stack.empty())
9314811270bSmax     throw std::runtime_error("Unbalanced Context enter/exit");
932436c6c9cSStella Laurenzo   auto &tos = stack.back();
933436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
9344811270bSmax     throw std::runtime_error("Unbalanced Context enter/exit");
935436c6c9cSStella Laurenzo   stack.pop_back();
936436c6c9cSStella Laurenzo }
937436c6c9cSStella Laurenzo 
938b56d1ec6SPeter Hawkins nb::object
939b56d1ec6SPeter Hawkins PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
940b56d1ec6SPeter Hawkins   PyInsertionPoint &insertionPoint =
941b56d1ec6SPeter Hawkins       nb::cast<PyInsertionPoint &>(insertionPointObj);
942b56d1ec6SPeter Hawkins   nb::object contextObj =
943436c6c9cSStella Laurenzo       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
944436c6c9cSStella Laurenzo   push(FrameKind::InsertionPoint,
945436c6c9cSStella Laurenzo        /*context=*/contextObj,
946436c6c9cSStella Laurenzo        /*insertionPoint=*/insertionPointObj,
947b56d1ec6SPeter Hawkins        /*location=*/nb::object());
948436c6c9cSStella Laurenzo   return insertionPointObj;
949436c6c9cSStella Laurenzo }
950436c6c9cSStella Laurenzo 
951436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
952436c6c9cSStella Laurenzo   auto &stack = getStack();
953436c6c9cSStella Laurenzo   if (stack.empty())
9544811270bSmax     throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
955436c6c9cSStella Laurenzo   auto &tos = stack.back();
956436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::InsertionPoint &&
957436c6c9cSStella Laurenzo       tos.getInsertionPoint() != &insertionPoint)
9584811270bSmax     throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
959436c6c9cSStella Laurenzo   stack.pop_back();
960436c6c9cSStella Laurenzo }
961436c6c9cSStella Laurenzo 
962b56d1ec6SPeter Hawkins nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
963b56d1ec6SPeter Hawkins   PyLocation &location = nb::cast<PyLocation &>(locationObj);
964b56d1ec6SPeter Hawkins   nb::object contextObj = location.getContext().getObject();
965436c6c9cSStella Laurenzo   push(FrameKind::Location, /*context=*/contextObj,
966b56d1ec6SPeter Hawkins        /*insertionPoint=*/nb::object(),
967436c6c9cSStella Laurenzo        /*location=*/locationObj);
968436c6c9cSStella Laurenzo   return locationObj;
969436c6c9cSStella Laurenzo }
970436c6c9cSStella Laurenzo 
971436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) {
972436c6c9cSStella Laurenzo   auto &stack = getStack();
973436c6c9cSStella Laurenzo   if (stack.empty())
9744811270bSmax     throw std::runtime_error("Unbalanced Location enter/exit");
975436c6c9cSStella Laurenzo   auto &tos = stack.back();
976436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
9774811270bSmax     throw std::runtime_error("Unbalanced Location enter/exit");
978436c6c9cSStella Laurenzo   stack.pop_back();
979436c6c9cSStella Laurenzo }
980436c6c9cSStella Laurenzo 
981436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
9827ee25bc5SStella Laurenzo // PyDiagnostic*
9837ee25bc5SStella Laurenzo //------------------------------------------------------------------------------
9847ee25bc5SStella Laurenzo 
9857ee25bc5SStella Laurenzo void PyDiagnostic::invalidate() {
9867ee25bc5SStella Laurenzo   valid = false;
9877ee25bc5SStella Laurenzo   if (materializedNotes) {
988b56d1ec6SPeter Hawkins     for (nb::handle noteObject : *materializedNotes) {
989b56d1ec6SPeter Hawkins       PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
9907ee25bc5SStella Laurenzo       note->invalidate();
9917ee25bc5SStella Laurenzo     }
9927ee25bc5SStella Laurenzo   }
9937ee25bc5SStella Laurenzo }
9947ee25bc5SStella Laurenzo 
9957ee25bc5SStella Laurenzo PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
996b56d1ec6SPeter Hawkins                                          nb::object callback)
9977ee25bc5SStella Laurenzo     : context(context), callback(std::move(callback)) {}
9987ee25bc5SStella Laurenzo 
9996a38cbfbSMehdi Amini PyDiagnosticHandler::~PyDiagnosticHandler() = default;
10007ee25bc5SStella Laurenzo 
10017ee25bc5SStella Laurenzo void PyDiagnosticHandler::detach() {
10027ee25bc5SStella Laurenzo   if (!registeredID)
10037ee25bc5SStella Laurenzo     return;
10047ee25bc5SStella Laurenzo   MlirDiagnosticHandlerID localID = *registeredID;
10057ee25bc5SStella Laurenzo   mlirContextDetachDiagnosticHandler(context, localID);
10067ee25bc5SStella Laurenzo   assert(!registeredID && "should have unregistered");
10077ee25bc5SStella Laurenzo   // Not strictly necessary but keeps stale pointers from being around to cause
10087ee25bc5SStella Laurenzo   // issues.
10097ee25bc5SStella Laurenzo   context = {nullptr};
10107ee25bc5SStella Laurenzo }
10117ee25bc5SStella Laurenzo 
10127ee25bc5SStella Laurenzo void PyDiagnostic::checkValid() {
10137ee25bc5SStella Laurenzo   if (!valid) {
10147ee25bc5SStella Laurenzo     throw std::invalid_argument(
10157ee25bc5SStella Laurenzo         "Diagnostic is invalid (used outside of callback)");
10167ee25bc5SStella Laurenzo   }
10177ee25bc5SStella Laurenzo }
10187ee25bc5SStella Laurenzo 
10197ee25bc5SStella Laurenzo MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
10207ee25bc5SStella Laurenzo   checkValid();
10217ee25bc5SStella Laurenzo   return mlirDiagnosticGetSeverity(diagnostic);
10227ee25bc5SStella Laurenzo }
10237ee25bc5SStella Laurenzo 
10247ee25bc5SStella Laurenzo PyLocation PyDiagnostic::getLocation() {
10257ee25bc5SStella Laurenzo   checkValid();
10267ee25bc5SStella Laurenzo   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
10277ee25bc5SStella Laurenzo   MlirContext context = mlirLocationGetContext(loc);
10287ee25bc5SStella Laurenzo   return PyLocation(PyMlirContext::forContext(context), loc);
10297ee25bc5SStella Laurenzo }
10307ee25bc5SStella Laurenzo 
1031b56d1ec6SPeter Hawkins nb::str PyDiagnostic::getMessage() {
10327ee25bc5SStella Laurenzo   checkValid();
1033b56d1ec6SPeter Hawkins   nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
10347ee25bc5SStella Laurenzo   PyFileAccumulator accum(fileObject, /*binary=*/false);
10357ee25bc5SStella Laurenzo   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
1036b56d1ec6SPeter Hawkins   return nb::cast<nb::str>(fileObject.attr("getvalue")());
10377ee25bc5SStella Laurenzo }
10387ee25bc5SStella Laurenzo 
1039b56d1ec6SPeter Hawkins nb::tuple PyDiagnostic::getNotes() {
10407ee25bc5SStella Laurenzo   checkValid();
10417ee25bc5SStella Laurenzo   if (materializedNotes)
10427ee25bc5SStella Laurenzo     return *materializedNotes;
10437ee25bc5SStella Laurenzo   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1044b56d1ec6SPeter Hawkins   nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
10457ee25bc5SStella Laurenzo   for (intptr_t i = 0; i < numNotes; ++i) {
10467ee25bc5SStella Laurenzo     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1047b56d1ec6SPeter Hawkins     nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
1048b56d1ec6SPeter Hawkins     PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
10497ee25bc5SStella Laurenzo   }
1050b56d1ec6SPeter Hawkins   materializedNotes = std::move(notes);
1051b56d1ec6SPeter Hawkins 
10527ee25bc5SStella Laurenzo   return *materializedNotes;
10537ee25bc5SStella Laurenzo }
10547ee25bc5SStella Laurenzo 
10553ea4c501SRahul Kayaith PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
10563ea4c501SRahul Kayaith   std::vector<DiagnosticInfo> notes;
1057b56d1ec6SPeter Hawkins   for (nb::handle n : getNotes())
1058b56d1ec6SPeter Hawkins     notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1059b56d1ec6SPeter Hawkins   return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1060b56d1ec6SPeter Hawkins           std::move(notes)};
10613ea4c501SRahul Kayaith }
10623ea4c501SRahul Kayaith 
10637ee25bc5SStella Laurenzo //------------------------------------------------------------------------------
10645e83a5b4SStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1065436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1066436c6c9cSStella Laurenzo 
1067436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key,
1068436c6c9cSStella Laurenzo                                          bool attrError) {
1069f8479d9dSRiver Riddle   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1070f8479d9dSRiver Riddle                                                     {key.data(), key.size()});
1071436c6c9cSStella Laurenzo   if (mlirDialectIsNull(dialect)) {
10724811270bSmax     std::string msg = (Twine("Dialect '") + key + "' not found").str();
10734811270bSmax     if (attrError)
1074b56d1ec6SPeter Hawkins       throw nb::attribute_error(msg.c_str());
1075b56d1ec6SPeter Hawkins     throw nb::index_error(msg.c_str());
1076436c6c9cSStella Laurenzo   }
1077436c6c9cSStella Laurenzo   return dialect;
1078436c6c9cSStella Laurenzo }
1079436c6c9cSStella Laurenzo 
1080b56d1ec6SPeter Hawkins nb::object PyDialectRegistry::getCapsule() {
1081b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
10825e83a5b4SStella Laurenzo }
10835e83a5b4SStella Laurenzo 
1084b56d1ec6SPeter Hawkins PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) {
10855e83a5b4SStella Laurenzo   MlirDialectRegistry rawRegistry =
10865e83a5b4SStella Laurenzo       mlirPythonCapsuleToDialectRegistry(capsule.ptr());
10875e83a5b4SStella Laurenzo   if (mlirDialectRegistryIsNull(rawRegistry))
1088b56d1ec6SPeter Hawkins     throw nb::python_error();
10895e83a5b4SStella Laurenzo   return PyDialectRegistry(rawRegistry);
10905e83a5b4SStella Laurenzo }
10915e83a5b4SStella Laurenzo 
1092436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1093436c6c9cSStella Laurenzo // PyLocation
1094436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1095436c6c9cSStella Laurenzo 
1096b56d1ec6SPeter Hawkins nb::object PyLocation::getCapsule() {
1097b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1098436c6c9cSStella Laurenzo }
1099436c6c9cSStella Laurenzo 
1100b56d1ec6SPeter Hawkins PyLocation PyLocation::createFromCapsule(nb::object capsule) {
1101436c6c9cSStella Laurenzo   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1102436c6c9cSStella Laurenzo   if (mlirLocationIsNull(rawLoc))
1103b56d1ec6SPeter Hawkins     throw nb::python_error();
1104436c6c9cSStella Laurenzo   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1105436c6c9cSStella Laurenzo                     rawLoc);
1106436c6c9cSStella Laurenzo }
1107436c6c9cSStella Laurenzo 
1108b56d1ec6SPeter Hawkins nb::object PyLocation::contextEnter(nb::object locationObj) {
1109b56d1ec6SPeter Hawkins   return PyThreadContextEntry::pushLocation(locationObj);
1110436c6c9cSStella Laurenzo }
1111436c6c9cSStella Laurenzo 
1112b56d1ec6SPeter Hawkins void PyLocation::contextExit(const nb::object &excType,
1113b56d1ec6SPeter Hawkins                              const nb::object &excVal,
1114b56d1ec6SPeter Hawkins                              const nb::object &excTb) {
1115436c6c9cSStella Laurenzo   PyThreadContextEntry::popLocation(*this);
1116436c6c9cSStella Laurenzo }
1117436c6c9cSStella Laurenzo 
1118436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() {
1119436c6c9cSStella Laurenzo   auto *location = PyThreadContextEntry::getDefaultLocation();
1120436c6c9cSStella Laurenzo   if (!location) {
11214811270bSmax     throw std::runtime_error(
1122436c6c9cSStella Laurenzo         "An MLIR function requires a Location but none was provided in the "
1123436c6c9cSStella Laurenzo         "call or from the surrounding environment. Either pass to the function "
1124436c6c9cSStella Laurenzo         "with a 'loc=' argument or establish a default using 'with loc:'");
1125436c6c9cSStella Laurenzo   }
1126436c6c9cSStella Laurenzo   return *location;
1127436c6c9cSStella Laurenzo }
1128436c6c9cSStella Laurenzo 
1129436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1130436c6c9cSStella Laurenzo // PyModule
1131436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1132436c6c9cSStella Laurenzo 
1133436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1134436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), module(module) {}
1135436c6c9cSStella Laurenzo 
1136436c6c9cSStella Laurenzo PyModule::~PyModule() {
1137b56d1ec6SPeter Hawkins   nb::gil_scoped_acquire acquire;
1138436c6c9cSStella Laurenzo   auto &liveModules = getContext()->liveModules;
1139436c6c9cSStella Laurenzo   assert(liveModules.count(module.ptr) == 1 &&
1140436c6c9cSStella Laurenzo          "destroying module not in live map");
1141436c6c9cSStella Laurenzo   liveModules.erase(module.ptr);
1142436c6c9cSStella Laurenzo   mlirModuleDestroy(module);
1143436c6c9cSStella Laurenzo }
1144436c6c9cSStella Laurenzo 
1145436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) {
1146436c6c9cSStella Laurenzo   MlirContext context = mlirModuleGetContext(module);
1147436c6c9cSStella Laurenzo   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1148436c6c9cSStella Laurenzo 
1149b56d1ec6SPeter Hawkins   nb::gil_scoped_acquire acquire;
1150436c6c9cSStella Laurenzo   auto &liveModules = contextRef->liveModules;
1151436c6c9cSStella Laurenzo   auto it = liveModules.find(module.ptr);
1152436c6c9cSStella Laurenzo   if (it == liveModules.end()) {
1153436c6c9cSStella Laurenzo     // Create.
1154436c6c9cSStella Laurenzo     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1155436c6c9cSStella Laurenzo     // Note that the default return value policy on cast is automatic_reference,
1156436c6c9cSStella Laurenzo     // which does not take ownership (delete will not be called).
1157436c6c9cSStella Laurenzo     // Just be explicit.
1158b56d1ec6SPeter Hawkins     nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1159436c6c9cSStella Laurenzo     unownedModule->handle = pyRef;
1160436c6c9cSStella Laurenzo     liveModules[module.ptr] =
1161436c6c9cSStella Laurenzo         std::make_pair(unownedModule->handle, unownedModule);
1162436c6c9cSStella Laurenzo     return PyModuleRef(unownedModule, std::move(pyRef));
1163436c6c9cSStella Laurenzo   }
1164436c6c9cSStella Laurenzo   // Use existing.
1165436c6c9cSStella Laurenzo   PyModule *existing = it->second.second;
1166b56d1ec6SPeter Hawkins   nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1167436c6c9cSStella Laurenzo   return PyModuleRef(existing, std::move(pyRef));
1168436c6c9cSStella Laurenzo }
1169436c6c9cSStella Laurenzo 
1170b56d1ec6SPeter Hawkins nb::object PyModule::createFromCapsule(nb::object capsule) {
1171436c6c9cSStella Laurenzo   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1172436c6c9cSStella Laurenzo   if (mlirModuleIsNull(rawModule))
1173b56d1ec6SPeter Hawkins     throw nb::python_error();
1174436c6c9cSStella Laurenzo   return forModule(rawModule).releaseObject();
1175436c6c9cSStella Laurenzo }
1176436c6c9cSStella Laurenzo 
1177b56d1ec6SPeter Hawkins nb::object PyModule::getCapsule() {
1178b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1179436c6c9cSStella Laurenzo }
1180436c6c9cSStella Laurenzo 
1181436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1182436c6c9cSStella Laurenzo // PyOperation
1183436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1184436c6c9cSStella Laurenzo 
1185436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1186436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), operation(operation) {}
1187436c6c9cSStella Laurenzo 
1188436c6c9cSStella Laurenzo PyOperation::~PyOperation() {
118949745f87SMike Urbach   // If the operation has already been invalidated there is nothing to do.
119049745f87SMike Urbach   if (!valid)
119149745f87SMike Urbach     return;
119267897d77SOleksandr "Alex" Zinenko 
119367897d77SOleksandr "Alex" Zinenko   // Otherwise, invalidate the operation and remove it from live map when it is
119467897d77SOleksandr "Alex" Zinenko   // attached.
119567897d77SOleksandr "Alex" Zinenko   if (isAttached()) {
119667897d77SOleksandr "Alex" Zinenko     getContext()->clearOperation(*this);
119767897d77SOleksandr "Alex" Zinenko   } else {
119867897d77SOleksandr "Alex" Zinenko     // And destroy it when it is detached, i.e. owned by Python, in which case
119967897d77SOleksandr "Alex" Zinenko     // all nested operations must be invalidated at removed from the live map as
120067897d77SOleksandr "Alex" Zinenko     // well.
120167897d77SOleksandr "Alex" Zinenko     erase();
1202436c6c9cSStella Laurenzo   }
1203436c6c9cSStella Laurenzo }
1204436c6c9cSStella Laurenzo 
1205e30b7030SPeter Hawkins namespace {
1206e30b7030SPeter Hawkins 
1207e30b7030SPeter Hawkins // Constructs a new object of type T in-place on the Python heap, returning a
1208e30b7030SPeter Hawkins // PyObjectRef to it, loosely analogous to std::make_shared<T>().
1209e30b7030SPeter Hawkins template <typename T, class... Args>
1210e30b7030SPeter Hawkins PyObjectRef<T> makeObjectRef(Args &&...args) {
1211e30b7030SPeter Hawkins   nb::handle type = nb::type<T>();
1212e30b7030SPeter Hawkins   nb::object instance = nb::inst_alloc(type);
1213e30b7030SPeter Hawkins   T *ptr = nb::inst_ptr<T>(instance);
1214e30b7030SPeter Hawkins   new (ptr) T(std::forward<Args>(args)...);
1215e30b7030SPeter Hawkins   nb::inst_mark_ready(instance);
1216e30b7030SPeter Hawkins   return PyObjectRef<T>(ptr, std::move(instance));
1217e30b7030SPeter Hawkins }
1218e30b7030SPeter Hawkins 
1219e30b7030SPeter Hawkins } // namespace
1220e30b7030SPeter Hawkins 
1221436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1222436c6c9cSStella Laurenzo                                            MlirOperation operation,
1223b56d1ec6SPeter Hawkins                                            nb::object parentKeepAlive) {
1224436c6c9cSStella Laurenzo   // Create.
1225e30b7030SPeter Hawkins   PyOperationRef unownedOperation =
1226e30b7030SPeter Hawkins       makeObjectRef<PyOperation>(std::move(contextRef), operation);
1227e30b7030SPeter Hawkins   unownedOperation->handle = unownedOperation.getObject();
1228436c6c9cSStella Laurenzo   if (parentKeepAlive) {
1229436c6c9cSStella Laurenzo     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1230436c6c9cSStella Laurenzo   }
1231e30b7030SPeter Hawkins   return unownedOperation;
1232436c6c9cSStella Laurenzo }
1233436c6c9cSStella Laurenzo 
1234436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1235436c6c9cSStella Laurenzo                                          MlirOperation operation,
1236b56d1ec6SPeter Hawkins                                          nb::object parentKeepAlive) {
1237e2c49a45SPeter Hawkins   nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1238436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
1239436c6c9cSStella Laurenzo   auto it = liveOperations.find(operation.ptr);
1240436c6c9cSStella Laurenzo   if (it == liveOperations.end()) {
1241436c6c9cSStella Laurenzo     // Create.
1242e2c49a45SPeter Hawkins     PyOperationRef result = createInstance(std::move(contextRef), operation,
1243436c6c9cSStella Laurenzo                                            std::move(parentKeepAlive));
1244e2c49a45SPeter Hawkins     liveOperations[operation.ptr] =
1245e2c49a45SPeter Hawkins         std::make_pair(result.getObject(), result.get());
1246e2c49a45SPeter Hawkins     return result;
1247436c6c9cSStella Laurenzo   }
1248436c6c9cSStella Laurenzo   // Use existing.
1249436c6c9cSStella Laurenzo   PyOperation *existing = it->second.second;
1250b56d1ec6SPeter Hawkins   nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1251436c6c9cSStella Laurenzo   return PyOperationRef(existing, std::move(pyRef));
1252436c6c9cSStella Laurenzo }
1253436c6c9cSStella Laurenzo 
1254436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1255436c6c9cSStella Laurenzo                                            MlirOperation operation,
1256b56d1ec6SPeter Hawkins                                            nb::object parentKeepAlive) {
1257e2c49a45SPeter Hawkins   nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1258436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
1259436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 0 &&
1260436c6c9cSStella Laurenzo          "cannot create detached operation that already exists");
1261436c6c9cSStella Laurenzo   (void)liveOperations;
1262436c6c9cSStella Laurenzo   PyOperationRef created = createInstance(std::move(contextRef), operation,
1263436c6c9cSStella Laurenzo                                           std::move(parentKeepAlive));
1264e2c49a45SPeter Hawkins   liveOperations[operation.ptr] =
1265e2c49a45SPeter Hawkins       std::make_pair(created.getObject(), created.get());
1266436c6c9cSStella Laurenzo   created->attached = false;
1267436c6c9cSStella Laurenzo   return created;
1268436c6c9cSStella Laurenzo }
1269436c6c9cSStella Laurenzo 
127037107e17Srkayaith PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
127137107e17Srkayaith                                   const std::string &sourceStr,
127237107e17Srkayaith                                   const std::string &sourceName) {
12733ea4c501SRahul Kayaith   PyMlirContext::ErrorCapture errors(contextRef);
127437107e17Srkayaith   MlirOperation op =
127537107e17Srkayaith       mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
127637107e17Srkayaith                                toMlirStringRef(sourceName));
127737107e17Srkayaith   if (mlirOperationIsNull(op))
12783ea4c501SRahul Kayaith     throw MLIRError("Unable to parse operation assembly", errors.take());
127937107e17Srkayaith   return PyOperation::createDetached(std::move(contextRef), op);
128037107e17Srkayaith }
128137107e17Srkayaith 
1282436c6c9cSStella Laurenzo void PyOperation::checkValid() const {
1283436c6c9cSStella Laurenzo   if (!valid) {
12844811270bSmax     throw std::runtime_error("the operation has been invalidated");
1285436c6c9cSStella Laurenzo   }
1286436c6c9cSStella Laurenzo }
1287436c6c9cSStella Laurenzo 
1288204acc5cSJacques Pienaar void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1289436c6c9cSStella Laurenzo                             bool enableDebugInfo, bool prettyDebugInfo,
1290ace1d0adSStella Laurenzo                             bool printGenericOpForm, bool useLocalScope,
1291b56d1ec6SPeter Hawkins                             bool assumeVerified, nb::object fileObject,
1292abad8455SJonas Rickert                             bool binary, bool skipRegions) {
1293436c6c9cSStella Laurenzo   PyOperation &operation = getOperation();
1294436c6c9cSStella Laurenzo   operation.checkValid();
1295436c6c9cSStella Laurenzo   if (fileObject.is_none())
1296b56d1ec6SPeter Hawkins     fileObject = nb::module_::import_("sys").attr("stdout");
1297436c6c9cSStella Laurenzo 
1298436c6c9cSStella Laurenzo   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1299436c6c9cSStella Laurenzo   if (largeElementsLimit)
1300436c6c9cSStella Laurenzo     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1301436c6c9cSStella Laurenzo   if (enableDebugInfo)
1302d0236611SRiver Riddle     mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1303d0236611SRiver Riddle                                        /*prettyForm=*/prettyDebugInfo);
1304436c6c9cSStella Laurenzo   if (printGenericOpForm)
1305436c6c9cSStella Laurenzo     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1306bccf27d9SMark Browning   if (useLocalScope)
1307bccf27d9SMark Browning     mlirOpPrintingFlagsUseLocalScope(flags);
13082aa12583SRahul Kayaith   if (assumeVerified)
13092aa12583SRahul Kayaith     mlirOpPrintingFlagsAssumeVerified(flags);
1310abad8455SJonas Rickert   if (skipRegions)
1311abad8455SJonas Rickert     mlirOpPrintingFlagsSkipRegions(flags);
1312436c6c9cSStella Laurenzo 
1313436c6c9cSStella Laurenzo   PyFileAccumulator accum(fileObject, binary);
1314436c6c9cSStella Laurenzo   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1315436c6c9cSStella Laurenzo                               accum.getUserData());
1316436c6c9cSStella Laurenzo   mlirOpPrintingFlagsDestroy(flags);
1317436c6c9cSStella Laurenzo }
1318436c6c9cSStella Laurenzo 
1319b56d1ec6SPeter Hawkins void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1320204acc5cSJacques Pienaar                             bool binary) {
1321204acc5cSJacques Pienaar   PyOperation &operation = getOperation();
1322204acc5cSJacques Pienaar   operation.checkValid();
1323204acc5cSJacques Pienaar   if (fileObject.is_none())
1324b56d1ec6SPeter Hawkins     fileObject = nb::module_::import_("sys").attr("stdout");
1325204acc5cSJacques Pienaar   PyFileAccumulator accum(fileObject, binary);
1326204acc5cSJacques Pienaar   mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1327204acc5cSJacques Pienaar                               accum.getUserData());
1328204acc5cSJacques Pienaar }
1329204acc5cSJacques Pienaar 
1330b56d1ec6SPeter Hawkins void PyOperationBase::writeBytecode(const nb::object &fileObject,
13310610e2f6SJacques Pienaar                                     std::optional<int64_t> bytecodeVersion) {
133289418ddcSMehdi Amini   PyOperation &operation = getOperation();
133389418ddcSMehdi Amini   operation.checkValid();
133489418ddcSMehdi Amini   PyFileAccumulator accum(fileObject, /*binary=*/true);
13350610e2f6SJacques Pienaar 
13360610e2f6SJacques Pienaar   if (!bytecodeVersion.has_value())
13370610e2f6SJacques Pienaar     return mlirOperationWriteBytecode(operation, accum.getCallback(),
133889418ddcSMehdi Amini                                       accum.getUserData());
13390610e2f6SJacques Pienaar 
13400610e2f6SJacques Pienaar   MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
13410610e2f6SJacques Pienaar   mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
13425c90e1ffSJacques Pienaar   MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
13430610e2f6SJacques Pienaar       operation, config, accum.getCallback(), accum.getUserData());
13449816cc91SAdam Paszke   mlirBytecodeWriterConfigDestroy(config);
13455c90e1ffSJacques Pienaar   if (mlirLogicalResultIsFailure(res))
1346b56d1ec6SPeter Hawkins     throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
13475c90e1ffSJacques Pienaar                            Twine(*bytecodeVersion))
1348b56d1ec6SPeter Hawkins                               .str()
1349b56d1ec6SPeter Hawkins                               .c_str());
135089418ddcSMehdi Amini }
135189418ddcSMehdi Amini 
135247148832SHideto Ueno void PyOperationBase::walk(
135347148832SHideto Ueno     std::function<MlirWalkResult(MlirOperation)> callback,
135447148832SHideto Ueno     MlirWalkOrder walkOrder) {
135547148832SHideto Ueno   PyOperation &operation = getOperation();
135647148832SHideto Ueno   operation.checkValid();
1357bc553646Stomnatan30   struct UserData {
1358bc553646Stomnatan30     std::function<MlirWalkResult(MlirOperation)> callback;
1359bc553646Stomnatan30     bool gotException;
1360bc553646Stomnatan30     std::string exceptionWhat;
1361b56d1ec6SPeter Hawkins     nb::object exceptionType;
1362bc553646Stomnatan30   };
1363bc553646Stomnatan30   UserData userData{callback, false, {}, {}};
136447148832SHideto Ueno   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
136547148832SHideto Ueno                                               void *userData) {
1366bc553646Stomnatan30     UserData *calleeUserData = static_cast<UserData *>(userData);
1367bc553646Stomnatan30     try {
1368bc553646Stomnatan30       return (calleeUserData->callback)(op);
1369b56d1ec6SPeter Hawkins     } catch (nb::python_error &e) {
1370bc553646Stomnatan30       calleeUserData->gotException = true;
1371b56d1ec6SPeter Hawkins       calleeUserData->exceptionWhat = std::string(e.what());
1372b56d1ec6SPeter Hawkins       calleeUserData->exceptionType = nb::borrow(e.type());
1373bc553646Stomnatan30       return MlirWalkResult::MlirWalkResultInterrupt;
1374bc553646Stomnatan30     }
137547148832SHideto Ueno   };
1376bc553646Stomnatan30   mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1377bc553646Stomnatan30   if (userData.gotException) {
1378bc553646Stomnatan30     std::string message("Exception raised in callback: ");
1379bc553646Stomnatan30     message.append(userData.exceptionWhat);
1380bc553646Stomnatan30     throw std::runtime_error(message);
1381bc553646Stomnatan30   }
138247148832SHideto Ueno }
138347148832SHideto Ueno 
1384b56d1ec6SPeter Hawkins nb::object PyOperationBase::getAsm(bool binary,
13850a81ace0SKazu Hirata                                    std::optional<int64_t> largeElementsLimit,
1386436c6c9cSStella Laurenzo                                    bool enableDebugInfo, bool prettyDebugInfo,
1387ace1d0adSStella Laurenzo                                    bool printGenericOpForm, bool useLocalScope,
1388abad8455SJonas Rickert                                    bool assumeVerified, bool skipRegions) {
1389b56d1ec6SPeter Hawkins   nb::object fileObject;
1390436c6c9cSStella Laurenzo   if (binary) {
1391b56d1ec6SPeter Hawkins     fileObject = nb::module_::import_("io").attr("BytesIO")();
1392436c6c9cSStella Laurenzo   } else {
1393b56d1ec6SPeter Hawkins     fileObject = nb::module_::import_("io").attr("StringIO")();
1394436c6c9cSStella Laurenzo   }
1395204acc5cSJacques Pienaar   print(/*largeElementsLimit=*/largeElementsLimit,
1396436c6c9cSStella Laurenzo         /*enableDebugInfo=*/enableDebugInfo,
1397436c6c9cSStella Laurenzo         /*prettyDebugInfo=*/prettyDebugInfo,
1398436c6c9cSStella Laurenzo         /*printGenericOpForm=*/printGenericOpForm,
1399ace1d0adSStella Laurenzo         /*useLocalScope=*/useLocalScope,
1400204acc5cSJacques Pienaar         /*assumeVerified=*/assumeVerified,
1401204acc5cSJacques Pienaar         /*fileObject=*/fileObject,
1402abad8455SJonas Rickert         /*binary=*/binary,
1403abad8455SJonas Rickert         /*skipRegions=*/skipRegions);
1404436c6c9cSStella Laurenzo 
1405436c6c9cSStella Laurenzo   return fileObject.attr("getvalue")();
1406436c6c9cSStella Laurenzo }
1407436c6c9cSStella Laurenzo 
140824685aaeSAlex Zinenko void PyOperationBase::moveAfter(PyOperationBase &other) {
140924685aaeSAlex Zinenko   PyOperation &operation = getOperation();
141024685aaeSAlex Zinenko   PyOperation &otherOp = other.getOperation();
141124685aaeSAlex Zinenko   operation.checkValid();
141224685aaeSAlex Zinenko   otherOp.checkValid();
141324685aaeSAlex Zinenko   mlirOperationMoveAfter(operation, otherOp);
141424685aaeSAlex Zinenko   operation.parentKeepAlive = otherOp.parentKeepAlive;
141524685aaeSAlex Zinenko }
141624685aaeSAlex Zinenko 
141724685aaeSAlex Zinenko void PyOperationBase::moveBefore(PyOperationBase &other) {
141824685aaeSAlex Zinenko   PyOperation &operation = getOperation();
141924685aaeSAlex Zinenko   PyOperation &otherOp = other.getOperation();
142024685aaeSAlex Zinenko   operation.checkValid();
142124685aaeSAlex Zinenko   otherOp.checkValid();
142224685aaeSAlex Zinenko   mlirOperationMoveBefore(operation, otherOp);
142324685aaeSAlex Zinenko   operation.parentKeepAlive = otherOp.parentKeepAlive;
142424685aaeSAlex Zinenko }
142524685aaeSAlex Zinenko 
14263ea4c501SRahul Kayaith bool PyOperationBase::verify() {
14273ea4c501SRahul Kayaith   PyOperation &op = getOperation();
14283ea4c501SRahul Kayaith   PyMlirContext::ErrorCapture errors(op.getContext());
14293ea4c501SRahul Kayaith   if (!mlirOperationVerify(op.get()))
14303ea4c501SRahul Kayaith     throw MLIRError("Verification failed", errors.take());
14313ea4c501SRahul Kayaith   return true;
14323ea4c501SRahul Kayaith }
14333ea4c501SRahul Kayaith 
14340a81ace0SKazu Hirata std::optional<PyOperationRef> PyOperation::getParentOperation() {
143549745f87SMike Urbach   checkValid();
1436436c6c9cSStella Laurenzo   if (!isAttached())
1437b56d1ec6SPeter Hawkins     throw nb::value_error("Detached operations have no parent");
1438436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationGetParentOperation(get());
1439436c6c9cSStella Laurenzo   if (mlirOperationIsNull(operation))
14401689dadeSJohn Demme     return {};
1441436c6c9cSStella Laurenzo   return PyOperation::forOperation(getContext(), operation);
1442436c6c9cSStella Laurenzo }
1443436c6c9cSStella Laurenzo 
1444436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() {
144549745f87SMike Urbach   checkValid();
14460a81ace0SKazu Hirata   std::optional<PyOperationRef> parentOperation = getParentOperation();
1447436c6c9cSStella Laurenzo   MlirBlock block = mlirOperationGetBlock(get());
1448436c6c9cSStella Laurenzo   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
14491689dadeSJohn Demme   assert(parentOperation && "Operation has no parent");
14501689dadeSJohn Demme   return PyBlock{std::move(*parentOperation), block};
1451436c6c9cSStella Laurenzo }
1452436c6c9cSStella Laurenzo 
1453b56d1ec6SPeter Hawkins nb::object PyOperation::getCapsule() {
145449745f87SMike Urbach   checkValid();
1455b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
14560126e906SJohn Demme }
14570126e906SJohn Demme 
1458b56d1ec6SPeter Hawkins nb::object PyOperation::createFromCapsule(nb::object capsule) {
14590126e906SJohn Demme   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
14600126e906SJohn Demme   if (mlirOperationIsNull(rawOperation))
1461b56d1ec6SPeter Hawkins     throw nb::python_error();
14620126e906SJohn Demme   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
146378bd1246SAlex Zinenko   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
14640126e906SJohn Demme       .releaseObject();
14650126e906SJohn Demme }
14660126e906SJohn Demme 
1467774818c0SDominik Grewe static void maybeInsertOperation(PyOperationRef &op,
1468b56d1ec6SPeter Hawkins                                  const nb::object &maybeIp) {
1469774818c0SDominik Grewe   // InsertPoint active?
1470b56d1ec6SPeter Hawkins   if (!maybeIp.is(nb::cast(false))) {
1471774818c0SDominik Grewe     PyInsertionPoint *ip;
1472774818c0SDominik Grewe     if (maybeIp.is_none()) {
1473774818c0SDominik Grewe       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1474774818c0SDominik Grewe     } else {
1475b56d1ec6SPeter Hawkins       ip = nb::cast<PyInsertionPoint *>(maybeIp);
1476774818c0SDominik Grewe     }
1477774818c0SDominik Grewe     if (ip)
1478774818c0SDominik Grewe       ip->insert(*op.get());
1479774818c0SDominik Grewe   }
1480774818c0SDominik Grewe }
1481774818c0SDominik Grewe 
1482f4125e02SPeter Hawkins nb::object PyOperation::create(std::string_view name,
14830a81ace0SKazu Hirata                                std::optional<std::vector<PyType *>> results,
1484*acde3f72SPeter Hawkins                                llvm::ArrayRef<MlirValue> operands,
1485b56d1ec6SPeter Hawkins                                std::optional<nb::dict> attributes,
14860a81ace0SKazu Hirata                                std::optional<std::vector<PyBlock *>> successors,
14870a81ace0SKazu Hirata                                int regions, DefaultingPyLocation location,
1488b56d1ec6SPeter Hawkins                                const nb::object &maybeIp, bool inferType) {
1489436c6c9cSStella Laurenzo   llvm::SmallVector<MlirType, 4> mlirResults;
1490436c6c9cSStella Laurenzo   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1491436c6c9cSStella Laurenzo   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
1492436c6c9cSStella Laurenzo 
1493436c6c9cSStella Laurenzo   // General parameter validation.
1494436c6c9cSStella Laurenzo   if (regions < 0)
1495b56d1ec6SPeter Hawkins     throw nb::value_error("number of regions must be >= 0");
1496436c6c9cSStella Laurenzo 
1497436c6c9cSStella Laurenzo   // Unpack/validate results.
1498436c6c9cSStella Laurenzo   if (results) {
1499436c6c9cSStella Laurenzo     mlirResults.reserve(results->size());
1500436c6c9cSStella Laurenzo     for (PyType *result : *results) {
1501436c6c9cSStella Laurenzo       // TODO: Verify result type originate from the same context.
1502436c6c9cSStella Laurenzo       if (!result)
1503b56d1ec6SPeter Hawkins         throw nb::value_error("result type cannot be None");
1504436c6c9cSStella Laurenzo       mlirResults.push_back(*result);
1505436c6c9cSStella Laurenzo     }
1506436c6c9cSStella Laurenzo   }
1507436c6c9cSStella Laurenzo   // Unpack/validate attributes.
1508436c6c9cSStella Laurenzo   if (attributes) {
1509436c6c9cSStella Laurenzo     mlirAttributes.reserve(attributes->size());
1510b56d1ec6SPeter Hawkins     for (std::pair<nb::handle, nb::handle> it : *attributes) {
1511436c6c9cSStella Laurenzo       std::string key;
1512436c6c9cSStella Laurenzo       try {
1513b56d1ec6SPeter Hawkins         key = nb::cast<std::string>(it.first);
1514b56d1ec6SPeter Hawkins       } catch (nb::cast_error &err) {
1515436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute key (not a string) when "
1516436c6c9cSStella Laurenzo                           "attempting to create the operation \"" +
1517f4125e02SPeter Hawkins                           std::string(name) + "\" (" + err.what() + ")";
1518b56d1ec6SPeter Hawkins         throw nb::type_error(msg.c_str());
1519436c6c9cSStella Laurenzo       }
1520436c6c9cSStella Laurenzo       try {
1521b56d1ec6SPeter Hawkins         auto &attribute = nb::cast<PyAttribute &>(it.second);
1522436c6c9cSStella Laurenzo         // TODO: Verify attribute originates from the same context.
1523436c6c9cSStella Laurenzo         mlirAttributes.emplace_back(std::move(key), attribute);
1524b56d1ec6SPeter Hawkins       } catch (nb::cast_error &err) {
1525b56d1ec6SPeter Hawkins         std::string msg = "Invalid attribute value for the key \"" + key +
1526b56d1ec6SPeter Hawkins                           "\" when attempting to create the operation \"" +
1527f4125e02SPeter Hawkins                           std::string(name) + "\" (" + err.what() + ")";
1528b56d1ec6SPeter Hawkins         throw nb::type_error(msg.c_str());
1529b56d1ec6SPeter Hawkins       } catch (std::runtime_error &) {
1530436c6c9cSStella Laurenzo         // This exception seems thrown when the value is "None".
1531436c6c9cSStella Laurenzo         std::string msg =
1532436c6c9cSStella Laurenzo             "Found an invalid (`None`?) attribute value for the key \"" + key +
1533f4125e02SPeter Hawkins             "\" when attempting to create the operation \"" +
1534f4125e02SPeter Hawkins             std::string(name) + "\"";
1535b56d1ec6SPeter Hawkins         throw std::runtime_error(msg);
1536436c6c9cSStella Laurenzo       }
1537436c6c9cSStella Laurenzo     }
1538436c6c9cSStella Laurenzo   }
1539436c6c9cSStella Laurenzo   // Unpack/validate successors.
1540436c6c9cSStella Laurenzo   if (successors) {
1541436c6c9cSStella Laurenzo     mlirSuccessors.reserve(successors->size());
1542436c6c9cSStella Laurenzo     for (auto *successor : *successors) {
1543436c6c9cSStella Laurenzo       // TODO: Verify successor originate from the same context.
1544436c6c9cSStella Laurenzo       if (!successor)
1545b56d1ec6SPeter Hawkins         throw nb::value_error("successor block cannot be None");
1546436c6c9cSStella Laurenzo       mlirSuccessors.push_back(successor->get());
1547436c6c9cSStella Laurenzo     }
1548436c6c9cSStella Laurenzo   }
1549436c6c9cSStella Laurenzo 
1550436c6c9cSStella Laurenzo   // Apply unpacked/validated to the operation state. Beyond this
1551436c6c9cSStella Laurenzo   // point, exceptions cannot be thrown or else the state will leak.
1552436c6c9cSStella Laurenzo   MlirOperationState state =
1553436c6c9cSStella Laurenzo       mlirOperationStateGet(toMlirStringRef(name), location);
1554*acde3f72SPeter Hawkins   if (!operands.empty())
1555*acde3f72SPeter Hawkins     mlirOperationStateAddOperands(&state, operands.size(), operands.data());
1556f573bc24SJacques Pienaar   state.enableResultTypeInference = inferType;
1557436c6c9cSStella Laurenzo   if (!mlirResults.empty())
1558436c6c9cSStella Laurenzo     mlirOperationStateAddResults(&state, mlirResults.size(),
1559436c6c9cSStella Laurenzo                                  mlirResults.data());
1560436c6c9cSStella Laurenzo   if (!mlirAttributes.empty()) {
1561436c6c9cSStella Laurenzo     // Note that the attribute names directly reference bytes in
1562436c6c9cSStella Laurenzo     // mlirAttributes, so that vector must not be changed from here
1563436c6c9cSStella Laurenzo     // on.
1564436c6c9cSStella Laurenzo     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1565436c6c9cSStella Laurenzo     mlirNamedAttributes.reserve(mlirAttributes.size());
1566436c6c9cSStella Laurenzo     for (auto &it : mlirAttributes)
1567436c6c9cSStella Laurenzo       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1568436c6c9cSStella Laurenzo           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1569436c6c9cSStella Laurenzo                             toMlirStringRef(it.first)),
1570436c6c9cSStella Laurenzo           it.second));
1571436c6c9cSStella Laurenzo     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1572436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1573436c6c9cSStella Laurenzo   }
1574436c6c9cSStella Laurenzo   if (!mlirSuccessors.empty())
1575436c6c9cSStella Laurenzo     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1576436c6c9cSStella Laurenzo                                     mlirSuccessors.data());
1577436c6c9cSStella Laurenzo   if (regions) {
1578436c6c9cSStella Laurenzo     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1579436c6c9cSStella Laurenzo     mlirRegions.resize(regions);
1580436c6c9cSStella Laurenzo     for (int i = 0; i < regions; ++i)
1581436c6c9cSStella Laurenzo       mlirRegions[i] = mlirRegionCreate();
1582436c6c9cSStella Laurenzo     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1583436c6c9cSStella Laurenzo                                       mlirRegions.data());
1584436c6c9cSStella Laurenzo   }
1585436c6c9cSStella Laurenzo 
1586436c6c9cSStella Laurenzo   // Construct the operation.
1587436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationCreate(&state);
1588f573bc24SJacques Pienaar   if (!operation.ptr)
1589b56d1ec6SPeter Hawkins     throw nb::value_error("Operation creation failed");
1590436c6c9cSStella Laurenzo   PyOperationRef created =
1591436c6c9cSStella Laurenzo       PyOperation::createDetached(location->getContext(), operation);
1592774818c0SDominik Grewe   maybeInsertOperation(created, maybeIp);
1593436c6c9cSStella Laurenzo 
159400a93e62SPeter Hawkins   return created.getObject();
1595436c6c9cSStella Laurenzo }
1596436c6c9cSStella Laurenzo 
1597b56d1ec6SPeter Hawkins nb::object PyOperation::clone(const nb::object &maybeIp) {
1598774818c0SDominik Grewe   MlirOperation clonedOperation = mlirOperationClone(operation);
1599774818c0SDominik Grewe   PyOperationRef cloned =
1600774818c0SDominik Grewe       PyOperation::createDetached(getContext(), clonedOperation);
1601774818c0SDominik Grewe   maybeInsertOperation(cloned, maybeIp);
1602774818c0SDominik Grewe 
1603774818c0SDominik Grewe   return cloned->createOpView();
1604774818c0SDominik Grewe }
1605774818c0SDominik Grewe 
1606b56d1ec6SPeter Hawkins nb::object PyOperation::createOpView() {
160749745f87SMike Urbach   checkValid();
1608436c6c9cSStella Laurenzo   MlirIdentifier ident = mlirOperationGetName(get());
1609436c6c9cSStella Laurenzo   MlirStringRef identStr = mlirIdentifierStr(ident);
1610a7f8b7cdSRahul Kayaith   auto operationCls = PyGlobals::get().lookupOperationClass(
1611436c6c9cSStella Laurenzo       StringRef(identStr.data, identStr.length));
1612a7f8b7cdSRahul Kayaith   if (operationCls)
1613b56d1ec6SPeter Hawkins     return PyOpView::constructDerived(*operationCls, getRef().getObject());
1614b56d1ec6SPeter Hawkins   return nb::cast(PyOpView(getRef().getObject()));
1615436c6c9cSStella Laurenzo }
1616436c6c9cSStella Laurenzo 
161749745f87SMike Urbach void PyOperation::erase() {
161849745f87SMike Urbach   checkValid();
161967897d77SOleksandr "Alex" Zinenko   getContext()->clearOperationAndInside(*this);
162049745f87SMike Urbach   mlirOperationDestroy(operation);
162149745f87SMike Urbach }
162249745f87SMike Urbach 
1623*acde3f72SPeter Hawkins namespace {
1624*acde3f72SPeter Hawkins /// CRTP base class for Python MLIR values that subclass Value and should be
1625*acde3f72SPeter Hawkins /// castable from it. The value hierarchy is one level deep and is not supposed
1626*acde3f72SPeter Hawkins /// to accommodate other levels unless core MLIR changes.
1627*acde3f72SPeter Hawkins template <typename DerivedTy>
1628*acde3f72SPeter Hawkins class PyConcreteValue : public PyValue {
1629*acde3f72SPeter Hawkins public:
1630*acde3f72SPeter Hawkins   // Derived classes must define statics for:
1631*acde3f72SPeter Hawkins   //   IsAFunctionTy isaFunction
1632*acde3f72SPeter Hawkins   //   const char *pyClassName
1633*acde3f72SPeter Hawkins   // and redefine bindDerived.
1634*acde3f72SPeter Hawkins   using ClassTy = nb::class_<DerivedTy, PyValue>;
1635*acde3f72SPeter Hawkins   using IsAFunctionTy = bool (*)(MlirValue);
1636*acde3f72SPeter Hawkins 
1637*acde3f72SPeter Hawkins   PyConcreteValue() = default;
1638*acde3f72SPeter Hawkins   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1639*acde3f72SPeter Hawkins       : PyValue(operationRef, value) {}
1640*acde3f72SPeter Hawkins   PyConcreteValue(PyValue &orig)
1641*acde3f72SPeter Hawkins       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1642*acde3f72SPeter Hawkins 
1643*acde3f72SPeter Hawkins   /// Attempts to cast the original value to the derived type and throws on
1644*acde3f72SPeter Hawkins   /// type mismatches.
1645*acde3f72SPeter Hawkins   static MlirValue castFrom(PyValue &orig) {
1646*acde3f72SPeter Hawkins     if (!DerivedTy::isaFunction(orig.get())) {
1647*acde3f72SPeter Hawkins       auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1648*acde3f72SPeter Hawkins       throw nb::value_error((Twine("Cannot cast value to ") +
1649*acde3f72SPeter Hawkins                              DerivedTy::pyClassName + " (from " + origRepr +
1650*acde3f72SPeter Hawkins                              ")")
1651*acde3f72SPeter Hawkins                                 .str()
1652*acde3f72SPeter Hawkins                                 .c_str());
1653*acde3f72SPeter Hawkins     }
1654*acde3f72SPeter Hawkins     return orig.get();
1655*acde3f72SPeter Hawkins   }
1656*acde3f72SPeter Hawkins 
1657*acde3f72SPeter Hawkins   /// Binds the Python module objects to functions of this class.
1658*acde3f72SPeter Hawkins   static void bind(nb::module_ &m) {
1659*acde3f72SPeter Hawkins     auto cls = ClassTy(m, DerivedTy::pyClassName);
1660*acde3f72SPeter Hawkins     cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1661*acde3f72SPeter Hawkins     cls.def_static(
1662*acde3f72SPeter Hawkins         "isinstance",
1663*acde3f72SPeter Hawkins         [](PyValue &otherValue) -> bool {
1664*acde3f72SPeter Hawkins           return DerivedTy::isaFunction(otherValue);
1665*acde3f72SPeter Hawkins         },
1666*acde3f72SPeter Hawkins         nb::arg("other_value"));
1667*acde3f72SPeter Hawkins     cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
1668*acde3f72SPeter Hawkins             [](DerivedTy &self) { return self.maybeDownCast(); });
1669*acde3f72SPeter Hawkins     DerivedTy::bindDerived(cls);
1670*acde3f72SPeter Hawkins   }
1671*acde3f72SPeter Hawkins 
1672*acde3f72SPeter Hawkins   /// Implemented by derived classes to add methods to the Python subclass.
1673*acde3f72SPeter Hawkins   static void bindDerived(ClassTy &m) {}
1674*acde3f72SPeter Hawkins };
1675*acde3f72SPeter Hawkins 
1676*acde3f72SPeter Hawkins } // namespace
1677*acde3f72SPeter Hawkins 
1678*acde3f72SPeter Hawkins /// Python wrapper for MlirOpResult.
1679*acde3f72SPeter Hawkins class PyOpResult : public PyConcreteValue<PyOpResult> {
1680*acde3f72SPeter Hawkins public:
1681*acde3f72SPeter Hawkins   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1682*acde3f72SPeter Hawkins   static constexpr const char *pyClassName = "OpResult";
1683*acde3f72SPeter Hawkins   using PyConcreteValue::PyConcreteValue;
1684*acde3f72SPeter Hawkins 
1685*acde3f72SPeter Hawkins   static void bindDerived(ClassTy &c) {
1686*acde3f72SPeter Hawkins     c.def_prop_ro("owner", [](PyOpResult &self) {
1687*acde3f72SPeter Hawkins       assert(
1688*acde3f72SPeter Hawkins           mlirOperationEqual(self.getParentOperation()->get(),
1689*acde3f72SPeter Hawkins                              mlirOpResultGetOwner(self.get())) &&
1690*acde3f72SPeter Hawkins           "expected the owner of the value in Python to match that in the IR");
1691*acde3f72SPeter Hawkins       return self.getParentOperation().getObject();
1692*acde3f72SPeter Hawkins     });
1693*acde3f72SPeter Hawkins     c.def_prop_ro("result_number", [](PyOpResult &self) {
1694*acde3f72SPeter Hawkins       return mlirOpResultGetResultNumber(self.get());
1695*acde3f72SPeter Hawkins     });
1696*acde3f72SPeter Hawkins   }
1697*acde3f72SPeter Hawkins };
1698*acde3f72SPeter Hawkins 
1699*acde3f72SPeter Hawkins /// Returns the list of types of the values held by container.
1700*acde3f72SPeter Hawkins template <typename Container>
1701*acde3f72SPeter Hawkins static std::vector<MlirType> getValueTypes(Container &container,
1702*acde3f72SPeter Hawkins                                            PyMlirContextRef &context) {
1703*acde3f72SPeter Hawkins   std::vector<MlirType> result;
1704*acde3f72SPeter Hawkins   result.reserve(container.size());
1705*acde3f72SPeter Hawkins   for (int i = 0, e = container.size(); i < e; ++i) {
1706*acde3f72SPeter Hawkins     result.push_back(mlirValueGetType(container.getElement(i).get()));
1707*acde3f72SPeter Hawkins   }
1708*acde3f72SPeter Hawkins   return result;
1709*acde3f72SPeter Hawkins }
1710*acde3f72SPeter Hawkins 
1711*acde3f72SPeter Hawkins /// A list of operation results. Internally, these are stored as consecutive
1712*acde3f72SPeter Hawkins /// elements, random access is cheap. The (returned) result list is associated
1713*acde3f72SPeter Hawkins /// with the operation whose results these are, and thus extends the lifetime of
1714*acde3f72SPeter Hawkins /// this operation.
1715*acde3f72SPeter Hawkins class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1716*acde3f72SPeter Hawkins public:
1717*acde3f72SPeter Hawkins   static constexpr const char *pyClassName = "OpResultList";
1718*acde3f72SPeter Hawkins   using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
1719*acde3f72SPeter Hawkins 
1720*acde3f72SPeter Hawkins   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1721*acde3f72SPeter Hawkins                  intptr_t length = -1, intptr_t step = 1)
1722*acde3f72SPeter Hawkins       : Sliceable(startIndex,
1723*acde3f72SPeter Hawkins                   length == -1 ? mlirOperationGetNumResults(operation->get())
1724*acde3f72SPeter Hawkins                                : length,
1725*acde3f72SPeter Hawkins                   step),
1726*acde3f72SPeter Hawkins         operation(std::move(operation)) {}
1727*acde3f72SPeter Hawkins 
1728*acde3f72SPeter Hawkins   static void bindDerived(ClassTy &c) {
1729*acde3f72SPeter Hawkins     c.def_prop_ro("types", [](PyOpResultList &self) {
1730*acde3f72SPeter Hawkins       return getValueTypes(self, self.operation->getContext());
1731*acde3f72SPeter Hawkins     });
1732*acde3f72SPeter Hawkins     c.def_prop_ro("owner", [](PyOpResultList &self) {
1733*acde3f72SPeter Hawkins       return self.operation->createOpView();
1734*acde3f72SPeter Hawkins     });
1735*acde3f72SPeter Hawkins   }
1736*acde3f72SPeter Hawkins 
1737*acde3f72SPeter Hawkins   PyOperationRef &getOperation() { return operation; }
1738*acde3f72SPeter Hawkins 
1739*acde3f72SPeter Hawkins private:
1740*acde3f72SPeter Hawkins   /// Give the parent CRTP class access to hook implementations below.
1741*acde3f72SPeter Hawkins   friend class Sliceable<PyOpResultList, PyOpResult>;
1742*acde3f72SPeter Hawkins 
1743*acde3f72SPeter Hawkins   intptr_t getRawNumElements() {
1744*acde3f72SPeter Hawkins     operation->checkValid();
1745*acde3f72SPeter Hawkins     return mlirOperationGetNumResults(operation->get());
1746*acde3f72SPeter Hawkins   }
1747*acde3f72SPeter Hawkins 
1748*acde3f72SPeter Hawkins   PyOpResult getRawElement(intptr_t index) {
1749*acde3f72SPeter Hawkins     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1750*acde3f72SPeter Hawkins     return PyOpResult(value);
1751*acde3f72SPeter Hawkins   }
1752*acde3f72SPeter Hawkins 
1753*acde3f72SPeter Hawkins   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1754*acde3f72SPeter Hawkins     return PyOpResultList(operation, startIndex, length, step);
1755*acde3f72SPeter Hawkins   }
1756*acde3f72SPeter Hawkins 
1757*acde3f72SPeter Hawkins   PyOperationRef operation;
1758*acde3f72SPeter Hawkins };
1759*acde3f72SPeter Hawkins 
1760436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1761436c6c9cSStella Laurenzo // PyOpView
1762436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1763436c6c9cSStella Laurenzo 
1764b56d1ec6SPeter Hawkins static void populateResultTypes(StringRef name, nb::list resultTypeList,
1765b56d1ec6SPeter Hawkins                                 const nb::object &resultSegmentSpecObj,
1766f573bc24SJacques Pienaar                                 std::vector<int32_t> &resultSegmentLengths,
1767f573bc24SJacques Pienaar                                 std::vector<PyType *> &resultTypes) {
1768436c6c9cSStella Laurenzo   resultTypes.reserve(resultTypeList.size());
1769436c6c9cSStella Laurenzo   if (resultSegmentSpecObj.is_none()) {
1770436c6c9cSStella Laurenzo     // Non-variadic result unpacking.
1771e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(resultTypeList)) {
1772436c6c9cSStella Laurenzo       try {
1773b56d1ec6SPeter Hawkins         resultTypes.push_back(nb::cast<PyType *>(it.value()));
1774436c6c9cSStella Laurenzo         if (!resultTypes.back())
1775b56d1ec6SPeter Hawkins           throw nb::cast_error();
1776b56d1ec6SPeter Hawkins       } catch (nb::cast_error &err) {
1777b56d1ec6SPeter Hawkins         throw nb::value_error((llvm::Twine("Result ") +
1778436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1779436c6c9cSStella Laurenzo                                name + "\" must be a Type (" + err.what() + ")")
1780b56d1ec6SPeter Hawkins                                   .str()
1781b56d1ec6SPeter Hawkins                                   .c_str());
1782436c6c9cSStella Laurenzo       }
1783436c6c9cSStella Laurenzo     }
1784436c6c9cSStella Laurenzo   } else {
1785436c6c9cSStella Laurenzo     // Sized result unpacking.
1786b56d1ec6SPeter Hawkins     auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1787436c6c9cSStella Laurenzo     if (resultSegmentSpec.size() != resultTypeList.size()) {
1788b56d1ec6SPeter Hawkins       throw nb::value_error((llvm::Twine("Operation \"") + name +
1789436c6c9cSStella Laurenzo                              "\" requires " +
1790436c6c9cSStella Laurenzo                              llvm::Twine(resultSegmentSpec.size()) +
1791436c6c9cSStella Laurenzo                              " result segments but was provided " +
1792436c6c9cSStella Laurenzo                              llvm::Twine(resultTypeList.size()))
1793b56d1ec6SPeter Hawkins                                 .str()
1794b56d1ec6SPeter Hawkins                                 .c_str());
1795436c6c9cSStella Laurenzo     }
1796436c6c9cSStella Laurenzo     resultSegmentLengths.reserve(resultTypeList.size());
1797e4853be2SMehdi Amini     for (const auto &it :
1798436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1799436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1800436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1801436c6c9cSStella Laurenzo         // Unpack unary element.
1802436c6c9cSStella Laurenzo         try {
1803b56d1ec6SPeter Hawkins           auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1804436c6c9cSStella Laurenzo           if (resultType) {
1805436c6c9cSStella Laurenzo             resultTypes.push_back(resultType);
1806436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(1);
1807436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1808436c6c9cSStella Laurenzo             // Allowed to be optional.
1809436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1810436c6c9cSStella Laurenzo           } else {
1811b56d1ec6SPeter Hawkins             throw nb::value_error(
1812b56d1ec6SPeter Hawkins                 (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1813b56d1ec6SPeter Hawkins                  " of operation \"" + name +
1814b56d1ec6SPeter Hawkins                  "\" must be a Type (was None and result is not optional)")
1815b56d1ec6SPeter Hawkins                     .str()
1816b56d1ec6SPeter Hawkins                     .c_str());
1817436c6c9cSStella Laurenzo           }
1818b56d1ec6SPeter Hawkins         } catch (nb::cast_error &err) {
1819b56d1ec6SPeter Hawkins           throw nb::value_error((llvm::Twine("Result ") +
1820436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1821436c6c9cSStella Laurenzo                                  name + "\" must be a Type (" + err.what() +
1822436c6c9cSStella Laurenzo                                  ")")
1823b56d1ec6SPeter Hawkins                                     .str()
1824b56d1ec6SPeter Hawkins                                     .c_str());
1825436c6c9cSStella Laurenzo         }
1826436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1827436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1828436c6c9cSStella Laurenzo         try {
1829436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1830436c6c9cSStella Laurenzo             // Treat it as an empty list.
1831436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1832436c6c9cSStella Laurenzo           } else {
1833436c6c9cSStella Laurenzo             // Unpack the list.
1834b56d1ec6SPeter Hawkins             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1835b56d1ec6SPeter Hawkins             for (nb::handle segmentItem : segment) {
1836b56d1ec6SPeter Hawkins               resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1837436c6c9cSStella Laurenzo               if (!resultTypes.back()) {
1838b56d1ec6SPeter Hawkins                 throw nb::type_error("contained a None item");
1839436c6c9cSStella Laurenzo               }
1840436c6c9cSStella Laurenzo             }
1841b56d1ec6SPeter Hawkins             resultSegmentLengths.push_back(nb::len(segment));
1842436c6c9cSStella Laurenzo           }
1843436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1844436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1845436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1846436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1847b56d1ec6SPeter Hawkins           throw nb::value_error((llvm::Twine("Result ") +
1848436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1849436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Types (" +
1850436c6c9cSStella Laurenzo                                  err.what() + ")")
1851b56d1ec6SPeter Hawkins                                     .str()
1852b56d1ec6SPeter Hawkins                                     .c_str());
1853436c6c9cSStella Laurenzo         }
1854436c6c9cSStella Laurenzo       } else {
1855b56d1ec6SPeter Hawkins         throw nb::value_error("Unexpected segment spec");
1856436c6c9cSStella Laurenzo       }
1857436c6c9cSStella Laurenzo     }
1858436c6c9cSStella Laurenzo   }
1859f573bc24SJacques Pienaar }
1860f573bc24SJacques Pienaar 
1861*acde3f72SPeter Hawkins static MlirValue getUniqueResult(MlirOperation operation) {
1862*acde3f72SPeter Hawkins   auto numResults = mlirOperationGetNumResults(operation);
1863*acde3f72SPeter Hawkins   if (numResults != 1) {
1864*acde3f72SPeter Hawkins     auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1865*acde3f72SPeter Hawkins     throw nb::value_error((Twine("Cannot call .result on operation ") +
1866*acde3f72SPeter Hawkins                            StringRef(name.data, name.length) + " which has " +
1867*acde3f72SPeter Hawkins                            Twine(numResults) +
1868*acde3f72SPeter Hawkins                            " results (it is only valid for operations with a "
1869*acde3f72SPeter Hawkins                            "single result)")
1870*acde3f72SPeter Hawkins                               .str()
1871*acde3f72SPeter Hawkins                               .c_str());
1872*acde3f72SPeter Hawkins   }
1873*acde3f72SPeter Hawkins   return mlirOperationGetResult(operation, 0);
1874*acde3f72SPeter Hawkins }
1875*acde3f72SPeter Hawkins 
1876*acde3f72SPeter Hawkins static MlirValue getOpResultOrValue(nb::handle operand) {
1877*acde3f72SPeter Hawkins   if (operand.is_none()) {
1878*acde3f72SPeter Hawkins     throw nb::value_error("contained a None item");
1879*acde3f72SPeter Hawkins   }
1880*acde3f72SPeter Hawkins   PyOperationBase *op;
1881*acde3f72SPeter Hawkins   if (nb::try_cast<PyOperationBase *>(operand, op)) {
1882*acde3f72SPeter Hawkins     return getUniqueResult(op->getOperation());
1883*acde3f72SPeter Hawkins   }
1884*acde3f72SPeter Hawkins   PyOpResultList *opResultList;
1885*acde3f72SPeter Hawkins   if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1886*acde3f72SPeter Hawkins     return getUniqueResult(opResultList->getOperation()->get());
1887*acde3f72SPeter Hawkins   }
1888*acde3f72SPeter Hawkins   PyValue *value;
1889*acde3f72SPeter Hawkins   if (nb::try_cast<PyValue *>(operand, value)) {
1890*acde3f72SPeter Hawkins     return value->get();
1891*acde3f72SPeter Hawkins   }
1892*acde3f72SPeter Hawkins   throw nb::value_error("is not a Value");
1893*acde3f72SPeter Hawkins }
1894*acde3f72SPeter Hawkins 
1895b56d1ec6SPeter Hawkins nb::object PyOpView::buildGeneric(
1896f4125e02SPeter Hawkins     std::string_view name, std::tuple<int, bool> opRegionSpec,
1897f4125e02SPeter Hawkins     nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1898f4125e02SPeter Hawkins     std::optional<nb::list> resultTypeList, nb::list operandList,
1899f4125e02SPeter Hawkins     std::optional<nb::dict> attributes,
1900f573bc24SJacques Pienaar     std::optional<std::vector<PyBlock *>> successors,
1901f573bc24SJacques Pienaar     std::optional<int> regions, DefaultingPyLocation location,
1902b56d1ec6SPeter Hawkins     const nb::object &maybeIp) {
1903f573bc24SJacques Pienaar   PyMlirContextRef context = location->getContext();
1904f4125e02SPeter Hawkins 
1905f573bc24SJacques Pienaar   // Class level operation construction metadata.
1906f573bc24SJacques Pienaar   // Operand and result segment specs are either none, which does no
1907f573bc24SJacques Pienaar   // variadic unpacking, or a list of ints with segment sizes, where each
1908f573bc24SJacques Pienaar   // element is either a positive number (typically 1 for a scalar) or -1 to
1909f573bc24SJacques Pienaar   // indicate that it is derived from the length of the same-indexed operand
1910f573bc24SJacques Pienaar   // or result (implying that it is a list at that position).
1911f573bc24SJacques Pienaar   std::vector<int32_t> operandSegmentLengths;
1912f573bc24SJacques Pienaar   std::vector<int32_t> resultSegmentLengths;
1913f573bc24SJacques Pienaar 
1914f573bc24SJacques Pienaar   // Validate/determine region count.
1915f573bc24SJacques Pienaar   int opMinRegionCount = std::get<0>(opRegionSpec);
1916f573bc24SJacques Pienaar   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1917f573bc24SJacques Pienaar   if (!regions) {
1918f573bc24SJacques Pienaar     regions = opMinRegionCount;
1919f573bc24SJacques Pienaar   }
1920f573bc24SJacques Pienaar   if (*regions < opMinRegionCount) {
1921b56d1ec6SPeter Hawkins     throw nb::value_error(
1922f573bc24SJacques Pienaar         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1923f573bc24SJacques Pienaar          llvm::Twine(opMinRegionCount) +
1924f573bc24SJacques Pienaar          " regions but was built with regions=" + llvm::Twine(*regions))
1925b56d1ec6SPeter Hawkins             .str()
1926b56d1ec6SPeter Hawkins             .c_str());
1927f573bc24SJacques Pienaar   }
1928f573bc24SJacques Pienaar   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1929b56d1ec6SPeter Hawkins     throw nb::value_error(
1930f573bc24SJacques Pienaar         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1931f573bc24SJacques Pienaar          llvm::Twine(opMinRegionCount) +
1932f573bc24SJacques Pienaar          " regions but was built with regions=" + llvm::Twine(*regions))
1933b56d1ec6SPeter Hawkins             .str()
1934b56d1ec6SPeter Hawkins             .c_str());
1935f573bc24SJacques Pienaar   }
1936f573bc24SJacques Pienaar 
1937f573bc24SJacques Pienaar   // Unpack results.
1938f573bc24SJacques Pienaar   std::vector<PyType *> resultTypes;
1939f573bc24SJacques Pienaar   if (resultTypeList.has_value()) {
1940f573bc24SJacques Pienaar     populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1941f573bc24SJacques Pienaar                         resultSegmentLengths, resultTypes);
1942f573bc24SJacques Pienaar   }
1943436c6c9cSStella Laurenzo 
1944436c6c9cSStella Laurenzo   // Unpack operands.
1945*acde3f72SPeter Hawkins   llvm::SmallVector<MlirValue, 4> operands;
1946436c6c9cSStella Laurenzo   operands.reserve(operands.size());
1947436c6c9cSStella Laurenzo   if (operandSegmentSpecObj.is_none()) {
1948436c6c9cSStella Laurenzo     // Non-sized operand unpacking.
1949e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(operandList)) {
1950436c6c9cSStella Laurenzo       try {
1951*acde3f72SPeter Hawkins         operands.push_back(getOpResultOrValue(it.value()));
1952*acde3f72SPeter Hawkins       } catch (nb::builtin_exception &err) {
1953b56d1ec6SPeter Hawkins         throw nb::value_error((llvm::Twine("Operand ") +
1954436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1955436c6c9cSStella Laurenzo                                name + "\" must be a Value (" + err.what() + ")")
1956b56d1ec6SPeter Hawkins                                   .str()
1957b56d1ec6SPeter Hawkins                                   .c_str());
1958436c6c9cSStella Laurenzo       }
1959436c6c9cSStella Laurenzo     }
1960436c6c9cSStella Laurenzo   } else {
1961436c6c9cSStella Laurenzo     // Sized operand unpacking.
1962b56d1ec6SPeter Hawkins     auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1963436c6c9cSStella Laurenzo     if (operandSegmentSpec.size() != operandList.size()) {
1964b56d1ec6SPeter Hawkins       throw nb::value_error((llvm::Twine("Operation \"") + name +
1965436c6c9cSStella Laurenzo                              "\" requires " +
1966436c6c9cSStella Laurenzo                              llvm::Twine(operandSegmentSpec.size()) +
1967436c6c9cSStella Laurenzo                              "operand segments but was provided " +
1968436c6c9cSStella Laurenzo                              llvm::Twine(operandList.size()))
1969b56d1ec6SPeter Hawkins                                 .str()
1970b56d1ec6SPeter Hawkins                                 .c_str());
1971436c6c9cSStella Laurenzo     }
1972436c6c9cSStella Laurenzo     operandSegmentLengths.reserve(operandList.size());
1973e4853be2SMehdi Amini     for (const auto &it :
1974436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1975436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1976436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1977436c6c9cSStella Laurenzo         // Unpack unary element.
1978*acde3f72SPeter Hawkins         auto &operand = std::get<0>(it.value());
1979*acde3f72SPeter Hawkins         if (!operand.is_none()) {
1980436c6c9cSStella Laurenzo           try {
1981*acde3f72SPeter Hawkins 
1982*acde3f72SPeter Hawkins             operands.push_back(getOpResultOrValue(operand));
1983*acde3f72SPeter Hawkins           } catch (nb::builtin_exception &err) {
1984*acde3f72SPeter Hawkins             throw nb::value_error((llvm::Twine("Operand ") +
1985*acde3f72SPeter Hawkins                                    llvm::Twine(it.index()) +
1986*acde3f72SPeter Hawkins                                    " of operation \"" + name +
1987*acde3f72SPeter Hawkins                                    "\" must be a Value (" + err.what() + ")")
1988*acde3f72SPeter Hawkins                                       .str()
1989*acde3f72SPeter Hawkins                                       .c_str());
1990*acde3f72SPeter Hawkins           }
1991*acde3f72SPeter Hawkins 
1992436c6c9cSStella Laurenzo           operandSegmentLengths.push_back(1);
1993436c6c9cSStella Laurenzo         } else if (segmentSpec == 0) {
1994436c6c9cSStella Laurenzo           // Allowed to be optional.
1995436c6c9cSStella Laurenzo           operandSegmentLengths.push_back(0);
1996436c6c9cSStella Laurenzo         } else {
1997b56d1ec6SPeter Hawkins           throw nb::value_error(
1998b56d1ec6SPeter Hawkins               (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
1999b56d1ec6SPeter Hawkins                " of operation \"" + name +
2000b56d1ec6SPeter Hawkins                "\" must be a Value (was None and operand is not optional)")
2001b56d1ec6SPeter Hawkins                   .str()
2002b56d1ec6SPeter Hawkins                   .c_str());
2003436c6c9cSStella Laurenzo         }
2004436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
2005436c6c9cSStella Laurenzo         // Unpack sequence by appending.
2006436c6c9cSStella Laurenzo         try {
2007436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
2008436c6c9cSStella Laurenzo             // Treat it as an empty list.
2009436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
2010436c6c9cSStella Laurenzo           } else {
2011436c6c9cSStella Laurenzo             // Unpack the list.
2012b56d1ec6SPeter Hawkins             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
2013b56d1ec6SPeter Hawkins             for (nb::handle segmentItem : segment) {
2014*acde3f72SPeter Hawkins               operands.push_back(getOpResultOrValue(segmentItem));
2015436c6c9cSStella Laurenzo             }
2016b56d1ec6SPeter Hawkins             operandSegmentLengths.push_back(nb::len(segment));
2017436c6c9cSStella Laurenzo           }
2018436c6c9cSStella Laurenzo         } catch (std::exception &err) {
2019436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
2020436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
2021436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
2022b56d1ec6SPeter Hawkins           throw nb::value_error((llvm::Twine("Operand ") +
2023436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
2024436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Values (" +
2025436c6c9cSStella Laurenzo                                  err.what() + ")")
2026b56d1ec6SPeter Hawkins                                     .str()
2027b56d1ec6SPeter Hawkins                                     .c_str());
2028436c6c9cSStella Laurenzo         }
2029436c6c9cSStella Laurenzo       } else {
2030b56d1ec6SPeter Hawkins         throw nb::value_error("Unexpected segment spec");
2031436c6c9cSStella Laurenzo       }
2032436c6c9cSStella Laurenzo     }
2033436c6c9cSStella Laurenzo   }
2034436c6c9cSStella Laurenzo 
2035436c6c9cSStella Laurenzo   // Merge operand/result segment lengths into attributes if needed.
2036436c6c9cSStella Laurenzo   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
2037436c6c9cSStella Laurenzo     // Dup.
2038436c6c9cSStella Laurenzo     if (attributes) {
2039b56d1ec6SPeter Hawkins       attributes = nb::dict(*attributes);
2040436c6c9cSStella Laurenzo     } else {
2041b56d1ec6SPeter Hawkins       attributes = nb::dict();
2042436c6c9cSStella Laurenzo     }
2043363b6559SMehdi Amini     if (attributes->contains("resultSegmentSizes") ||
2044363b6559SMehdi Amini         attributes->contains("operandSegmentSizes")) {
2045b56d1ec6SPeter Hawkins       throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
2046363b6559SMehdi Amini                             "'operandSegmentSizes' attribute is unsupported. "
2047436c6c9cSStella Laurenzo                             "Use Operation.create for such low-level access.");
2048436c6c9cSStella Laurenzo     }
2049436c6c9cSStella Laurenzo 
2050363b6559SMehdi Amini     // Add resultSegmentSizes attribute.
2051436c6c9cSStella Laurenzo     if (!resultSegmentLengths.empty()) {
205258a47508SJeff Niu       MlirAttribute segmentLengthAttr =
205358a47508SJeff Niu           mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
205458a47508SJeff Niu                                resultSegmentLengths.data());
2055363b6559SMehdi Amini       (*attributes)["resultSegmentSizes"] =
2056436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
2057436c6c9cSStella Laurenzo     }
2058436c6c9cSStella Laurenzo 
2059363b6559SMehdi Amini     // Add operandSegmentSizes attribute.
2060436c6c9cSStella Laurenzo     if (!operandSegmentLengths.empty()) {
206158a47508SJeff Niu       MlirAttribute segmentLengthAttr =
206258a47508SJeff Niu           mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
206358a47508SJeff Niu                                operandSegmentLengths.data());
2064363b6559SMehdi Amini       (*attributes)["operandSegmentSizes"] =
2065436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
2066436c6c9cSStella Laurenzo     }
2067436c6c9cSStella Laurenzo   }
2068436c6c9cSStella Laurenzo 
2069436c6c9cSStella Laurenzo   // Delegate to create.
2070337c937dSMehdi Amini   return PyOperation::create(name,
2071436c6c9cSStella Laurenzo                              /*results=*/std::move(resultTypes),
2072436c6c9cSStella Laurenzo                              /*operands=*/std::move(operands),
2073436c6c9cSStella Laurenzo                              /*attributes=*/std::move(attributes),
2074436c6c9cSStella Laurenzo                              /*successors=*/std::move(successors),
2075f573bc24SJacques Pienaar                              /*regions=*/*regions, location, maybeIp,
2076f573bc24SJacques Pienaar                              !resultTypeList);
2077436c6c9cSStella Laurenzo }
2078436c6c9cSStella Laurenzo 
2079b56d1ec6SPeter Hawkins nb::object PyOpView::constructDerived(const nb::object &cls,
2080b56d1ec6SPeter Hawkins                                       const nb::object &operation) {
2081b56d1ec6SPeter Hawkins   nb::handle opViewType = nb::type<PyOpView>();
2082b56d1ec6SPeter Hawkins   nb::object instance = cls.attr("__new__")(cls);
2083a7f8b7cdSRahul Kayaith   opViewType.attr("__init__")(instance, operation);
2084a7f8b7cdSRahul Kayaith   return instance;
2085a7f8b7cdSRahul Kayaith }
2086a7f8b7cdSRahul Kayaith 
2087b56d1ec6SPeter Hawkins PyOpView::PyOpView(const nb::object &operationObject)
2088436c6c9cSStella Laurenzo     // Casting through the PyOperationBase base-class and then back to the
2089436c6c9cSStella Laurenzo     // Operation lets us accept any PyOperationBase subclass.
2090b56d1ec6SPeter Hawkins     : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
2091436c6c9cSStella Laurenzo       operationObject(operation.getRef().getObject()) {}
2092436c6c9cSStella Laurenzo 
2093436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2094436c6c9cSStella Laurenzo // PyInsertionPoint.
2095436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2096436c6c9cSStella Laurenzo 
2097436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
2098436c6c9cSStella Laurenzo 
2099436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
2100436c6c9cSStella Laurenzo     : refOperation(beforeOperationBase.getOperation().getRef()),
2101436c6c9cSStella Laurenzo       block((*refOperation)->getBlock()) {}
2102436c6c9cSStella Laurenzo 
2103436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) {
2104436c6c9cSStella Laurenzo   PyOperation &operation = operationBase.getOperation();
2105436c6c9cSStella Laurenzo   if (operation.isAttached())
2106b56d1ec6SPeter Hawkins     throw nb::value_error(
2107436c6c9cSStella Laurenzo         "Attempt to insert operation that is already attached");
2108436c6c9cSStella Laurenzo   block.getParentOperation()->checkValid();
2109436c6c9cSStella Laurenzo   MlirOperation beforeOp = {nullptr};
2110436c6c9cSStella Laurenzo   if (refOperation) {
2111436c6c9cSStella Laurenzo     // Insert before operation.
2112436c6c9cSStella Laurenzo     (*refOperation)->checkValid();
2113436c6c9cSStella Laurenzo     beforeOp = (*refOperation)->get();
2114436c6c9cSStella Laurenzo   } else {
2115436c6c9cSStella Laurenzo     // Insert at end (before null) is only valid if the block does not
2116436c6c9cSStella Laurenzo     // already end in a known terminator (violating this will cause assertion
2117436c6c9cSStella Laurenzo     // failures later).
2118436c6c9cSStella Laurenzo     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
2119b56d1ec6SPeter Hawkins       throw nb::index_error("Cannot insert operation at the end of a block "
2120436c6c9cSStella Laurenzo                             "that already has a terminator. Did you mean to "
2121436c6c9cSStella Laurenzo                             "use 'InsertionPoint.at_block_terminator(block)' "
2122436c6c9cSStella Laurenzo                             "versus 'InsertionPoint(block)'?");
2123436c6c9cSStella Laurenzo     }
2124436c6c9cSStella Laurenzo   }
2125436c6c9cSStella Laurenzo   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
2126436c6c9cSStella Laurenzo   operation.setAttached();
2127436c6c9cSStella Laurenzo }
2128436c6c9cSStella Laurenzo 
2129436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
2130436c6c9cSStella Laurenzo   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
2131436c6c9cSStella Laurenzo   if (mlirOperationIsNull(firstOp)) {
2132436c6c9cSStella Laurenzo     // Just insert at end.
2133436c6c9cSStella Laurenzo     return PyInsertionPoint(block);
2134436c6c9cSStella Laurenzo   }
2135436c6c9cSStella Laurenzo 
2136436c6c9cSStella Laurenzo   // Insert before first op.
2137436c6c9cSStella Laurenzo   PyOperationRef firstOpRef = PyOperation::forOperation(
2138436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), firstOp);
2139436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(firstOpRef)};
2140436c6c9cSStella Laurenzo }
2141436c6c9cSStella Laurenzo 
2142436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
2143436c6c9cSStella Laurenzo   MlirOperation terminator = mlirBlockGetTerminator(block.get());
2144436c6c9cSStella Laurenzo   if (mlirOperationIsNull(terminator))
2145b56d1ec6SPeter Hawkins     throw nb::value_error("Block has no terminator");
2146436c6c9cSStella Laurenzo   PyOperationRef terminatorOpRef = PyOperation::forOperation(
2147436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), terminator);
2148436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(terminatorOpRef)};
2149436c6c9cSStella Laurenzo }
2150436c6c9cSStella Laurenzo 
2151b56d1ec6SPeter Hawkins nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2152b56d1ec6SPeter Hawkins   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
2153436c6c9cSStella Laurenzo }
2154436c6c9cSStella Laurenzo 
2155b56d1ec6SPeter Hawkins void PyInsertionPoint::contextExit(const nb::object &excType,
2156b56d1ec6SPeter Hawkins                                    const nb::object &excVal,
2157b56d1ec6SPeter Hawkins                                    const nb::object &excTb) {
2158436c6c9cSStella Laurenzo   PyThreadContextEntry::popInsertionPoint(*this);
2159436c6c9cSStella Laurenzo }
2160436c6c9cSStella Laurenzo 
2161436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2162436c6c9cSStella Laurenzo // PyAttribute.
2163436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2164436c6c9cSStella Laurenzo 
2165e6d738e0SRahul Kayaith bool PyAttribute::operator==(const PyAttribute &other) const {
2166436c6c9cSStella Laurenzo   return mlirAttributeEqual(attr, other.attr);
2167436c6c9cSStella Laurenzo }
2168436c6c9cSStella Laurenzo 
2169b56d1ec6SPeter Hawkins nb::object PyAttribute::getCapsule() {
2170b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2171436c6c9cSStella Laurenzo }
2172436c6c9cSStella Laurenzo 
2173b56d1ec6SPeter Hawkins PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
2174436c6c9cSStella Laurenzo   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2175436c6c9cSStella Laurenzo   if (mlirAttributeIsNull(rawAttr))
2176b56d1ec6SPeter Hawkins     throw nb::python_error();
2177436c6c9cSStella Laurenzo   return PyAttribute(
2178436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
2179436c6c9cSStella Laurenzo }
2180436c6c9cSStella Laurenzo 
2181436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2182436c6c9cSStella Laurenzo // PyNamedAttribute.
2183436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2184436c6c9cSStella Laurenzo 
2185436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2186436c6c9cSStella Laurenzo     : ownedName(new std::string(std::move(ownedName))) {
2187436c6c9cSStella Laurenzo   namedAttr = mlirNamedAttributeGet(
2188436c6c9cSStella Laurenzo       mlirIdentifierGet(mlirAttributeGetContext(attr),
2189436c6c9cSStella Laurenzo                         toMlirStringRef(*this->ownedName)),
2190436c6c9cSStella Laurenzo       attr);
2191436c6c9cSStella Laurenzo }
2192436c6c9cSStella Laurenzo 
2193436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2194436c6c9cSStella Laurenzo // PyType.
2195436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2196436c6c9cSStella Laurenzo 
2197e6d738e0SRahul Kayaith bool PyType::operator==(const PyType &other) const {
2198436c6c9cSStella Laurenzo   return mlirTypeEqual(type, other.type);
2199436c6c9cSStella Laurenzo }
2200436c6c9cSStella Laurenzo 
2201b56d1ec6SPeter Hawkins nb::object PyType::getCapsule() {
2202b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2203436c6c9cSStella Laurenzo }
2204436c6c9cSStella Laurenzo 
2205b56d1ec6SPeter Hawkins PyType PyType::createFromCapsule(nb::object capsule) {
2206436c6c9cSStella Laurenzo   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2207436c6c9cSStella Laurenzo   if (mlirTypeIsNull(rawType))
2208b56d1ec6SPeter Hawkins     throw nb::python_error();
2209436c6c9cSStella Laurenzo   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
2210436c6c9cSStella Laurenzo                 rawType);
2211436c6c9cSStella Laurenzo }
2212436c6c9cSStella Laurenzo 
2213436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2214d39a7844Smax // PyTypeID.
2215d39a7844Smax //------------------------------------------------------------------------------
2216d39a7844Smax 
2217b56d1ec6SPeter Hawkins nb::object PyTypeID::getCapsule() {
2218b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2219d39a7844Smax }
2220d39a7844Smax 
2221b56d1ec6SPeter Hawkins PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
2222d39a7844Smax   MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2223d39a7844Smax   if (mlirTypeIDIsNull(mlirTypeID))
2224b56d1ec6SPeter Hawkins     throw nb::python_error();
2225d39a7844Smax   return PyTypeID(mlirTypeID);
2226d39a7844Smax }
2227d39a7844Smax bool PyTypeID::operator==(const PyTypeID &other) const {
2228d39a7844Smax   return mlirTypeIDEqual(typeID, other.typeID);
2229d39a7844Smax }
2230d39a7844Smax 
2231d39a7844Smax //------------------------------------------------------------------------------
22327c850867SMaksim Levental // PyValue and subclasses.
2233436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2234436c6c9cSStella Laurenzo 
2235b56d1ec6SPeter Hawkins nb::object PyValue::getCapsule() {
2236b56d1ec6SPeter Hawkins   return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
22373f3d1c90SMike Urbach }
22383f3d1c90SMike Urbach 
2239b56d1ec6SPeter Hawkins nb::object PyValue::maybeDownCast() {
22407c850867SMaksim Levental   MlirType type = mlirValueGetType(get());
22417c850867SMaksim Levental   MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
22427c850867SMaksim Levental   assert(!mlirTypeIDIsNull(mlirTypeID) &&
22437c850867SMaksim Levental          "mlirTypeID was expected to be non-null.");
2244b56d1ec6SPeter Hawkins   std::optional<nb::callable> valueCaster =
22457c850867SMaksim Levental       PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2246b56d1ec6SPeter Hawkins   // nb::rv_policy::move means use std::move to move the return value
22477c850867SMaksim Levental   // contents into a new instance that will be owned by Python.
2248b56d1ec6SPeter Hawkins   nb::object thisObj = nb::cast(this, nb::rv_policy::move);
22497c850867SMaksim Levental   if (!valueCaster)
22507c850867SMaksim Levental     return thisObj;
22517c850867SMaksim Levental   return valueCaster.value()(thisObj);
22527c850867SMaksim Levental }
22537c850867SMaksim Levental 
2254b56d1ec6SPeter Hawkins PyValue PyValue::createFromCapsule(nb::object capsule) {
22553f3d1c90SMike Urbach   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
22563f3d1c90SMike Urbach   if (mlirValueIsNull(value))
2257b56d1ec6SPeter Hawkins     throw nb::python_error();
22583f3d1c90SMike Urbach   MlirOperation owner;
22593f3d1c90SMike Urbach   if (mlirValueIsAOpResult(value))
22603f3d1c90SMike Urbach     owner = mlirOpResultGetOwner(value);
22613f3d1c90SMike Urbach   if (mlirValueIsABlockArgument(value))
22623f3d1c90SMike Urbach     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
22633f3d1c90SMike Urbach   if (mlirOperationIsNull(owner))
2264b56d1ec6SPeter Hawkins     throw nb::python_error();
22653f3d1c90SMike Urbach   MlirContext ctx = mlirOperationGetContext(owner);
22663f3d1c90SMike Urbach   PyOperationRef ownerRef =
22673f3d1c90SMike Urbach       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
22683f3d1c90SMike Urbach   return PyValue(ownerRef, value);
22693f3d1c90SMike Urbach }
22703f3d1c90SMike Urbach 
227130d61893SAlex Zinenko //------------------------------------------------------------------------------
227230d61893SAlex Zinenko // PySymbolTable.
227330d61893SAlex Zinenko //------------------------------------------------------------------------------
227430d61893SAlex Zinenko 
227530d61893SAlex Zinenko PySymbolTable::PySymbolTable(PyOperationBase &operation)
227630d61893SAlex Zinenko     : operation(operation.getOperation().getRef()) {
227730d61893SAlex Zinenko   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
227830d61893SAlex Zinenko   if (mlirSymbolTableIsNull(symbolTable)) {
2279b56d1ec6SPeter Hawkins     throw nb::type_error("Operation is not a Symbol Table.");
228030d61893SAlex Zinenko   }
228130d61893SAlex Zinenko }
228230d61893SAlex Zinenko 
2283b56d1ec6SPeter Hawkins nb::object PySymbolTable::dunderGetItem(const std::string &name) {
228430d61893SAlex Zinenko   operation->checkValid();
228530d61893SAlex Zinenko   MlirOperation symbol = mlirSymbolTableLookup(
228630d61893SAlex Zinenko       symbolTable, mlirStringRefCreate(name.data(), name.length()));
228730d61893SAlex Zinenko   if (mlirOperationIsNull(symbol))
2288b56d1ec6SPeter Hawkins     throw nb::key_error(
2289b56d1ec6SPeter Hawkins         ("Symbol '" + name + "' not in the symbol table.").c_str());
229030d61893SAlex Zinenko 
229130d61893SAlex Zinenko   return PyOperation::forOperation(operation->getContext(), symbol,
229230d61893SAlex Zinenko                                    operation.getObject())
229330d61893SAlex Zinenko       ->createOpView();
229430d61893SAlex Zinenko }
229530d61893SAlex Zinenko 
229630d61893SAlex Zinenko void PySymbolTable::erase(PyOperationBase &symbol) {
229730d61893SAlex Zinenko   operation->checkValid();
229830d61893SAlex Zinenko   symbol.getOperation().checkValid();
229930d61893SAlex Zinenko   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
230030d61893SAlex Zinenko   // The operation is also erased, so we must invalidate it. There may be Python
230130d61893SAlex Zinenko   // references to this operation so we don't want to delete it from the list of
230230d61893SAlex Zinenko   // live operations here.
230330d61893SAlex Zinenko   symbol.getOperation().valid = false;
230430d61893SAlex Zinenko }
230530d61893SAlex Zinenko 
230630d61893SAlex Zinenko void PySymbolTable::dunderDel(const std::string &name) {
2307b56d1ec6SPeter Hawkins   nb::object operation = dunderGetItem(name);
2308b56d1ec6SPeter Hawkins   erase(nb::cast<PyOperationBase &>(operation));
230930d61893SAlex Zinenko }
231030d61893SAlex Zinenko 
2311974c1596SRahul Kayaith MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
231230d61893SAlex Zinenko   operation->checkValid();
231330d61893SAlex Zinenko   symbol.getOperation().checkValid();
231430d61893SAlex Zinenko   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
231530d61893SAlex Zinenko       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
231630d61893SAlex Zinenko   if (mlirAttributeIsNull(symbolAttr))
2317b56d1ec6SPeter Hawkins     throw nb::value_error("Expected operation to have a symbol name.");
2318974c1596SRahul Kayaith   return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
231930d61893SAlex Zinenko }
232030d61893SAlex Zinenko 
2321974c1596SRahul Kayaith MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2322bdc31837SStella Laurenzo   // Op must already be a symbol.
2323bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
2324bdc31837SStella Laurenzo   operation.checkValid();
2325bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2326bdc31837SStella Laurenzo   MlirAttribute existingNameAttr =
2327bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
2328bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingNameAttr))
2329b56d1ec6SPeter Hawkins     throw nb::value_error("Expected operation to have a symbol name.");
2330974c1596SRahul Kayaith   return existingNameAttr;
2331bdc31837SStella Laurenzo }
2332bdc31837SStella Laurenzo 
2333bdc31837SStella Laurenzo void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2334bdc31837SStella Laurenzo                                   const std::string &name) {
2335bdc31837SStella Laurenzo   // Op must already be a symbol.
2336bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
2337bdc31837SStella Laurenzo   operation.checkValid();
2338bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2339bdc31837SStella Laurenzo   MlirAttribute existingNameAttr =
2340bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
2341bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingNameAttr))
2342b56d1ec6SPeter Hawkins     throw nb::value_error("Expected operation to have a symbol name.");
2343bdc31837SStella Laurenzo   MlirAttribute newNameAttr =
2344bdc31837SStella Laurenzo       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2345bdc31837SStella Laurenzo   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2346bdc31837SStella Laurenzo }
2347bdc31837SStella Laurenzo 
2348974c1596SRahul Kayaith MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2349bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
2350bdc31837SStella Laurenzo   operation.checkValid();
2351bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2352bdc31837SStella Laurenzo   MlirAttribute existingVisAttr =
2353bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
2354bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingVisAttr))
2355b56d1ec6SPeter Hawkins     throw nb::value_error("Expected operation to have a symbol visibility.");
2356974c1596SRahul Kayaith   return existingVisAttr;
2357bdc31837SStella Laurenzo }
2358bdc31837SStella Laurenzo 
2359bdc31837SStella Laurenzo void PySymbolTable::setVisibility(PyOperationBase &symbol,
2360bdc31837SStella Laurenzo                                   const std::string &visibility) {
2361bdc31837SStella Laurenzo   if (visibility != "public" && visibility != "private" &&
2362bdc31837SStella Laurenzo       visibility != "nested")
2363b56d1ec6SPeter Hawkins     throw nb::value_error(
2364bdc31837SStella Laurenzo         "Expected visibility to be 'public', 'private' or 'nested'");
2365bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
2366bdc31837SStella Laurenzo   operation.checkValid();
2367bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2368bdc31837SStella Laurenzo   MlirAttribute existingVisAttr =
2369bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
2370bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingVisAttr))
2371b56d1ec6SPeter Hawkins     throw nb::value_error("Expected operation to have a symbol visibility.");
2372bdc31837SStella Laurenzo   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2373bdc31837SStella Laurenzo                                                toMlirStringRef(visibility));
2374bdc31837SStella Laurenzo   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2375bdc31837SStella Laurenzo }
2376bdc31837SStella Laurenzo 
2377bdc31837SStella Laurenzo void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2378bdc31837SStella Laurenzo                                          const std::string &newSymbol,
2379bdc31837SStella Laurenzo                                          PyOperationBase &from) {
2380bdc31837SStella Laurenzo   PyOperation &fromOperation = from.getOperation();
2381bdc31837SStella Laurenzo   fromOperation.checkValid();
2382bdc31837SStella Laurenzo   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2383bdc31837SStella Laurenzo           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2384bdc31837SStella Laurenzo           from.getOperation())))
2385bdc31837SStella Laurenzo 
2386b56d1ec6SPeter Hawkins     throw nb::value_error("Symbol rename failed");
2387bdc31837SStella Laurenzo }
2388bdc31837SStella Laurenzo 
2389bdc31837SStella Laurenzo void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2390bdc31837SStella Laurenzo                                      bool allSymUsesVisible,
2391b56d1ec6SPeter Hawkins                                      nb::object callback) {
2392bdc31837SStella Laurenzo   PyOperation &fromOperation = from.getOperation();
2393bdc31837SStella Laurenzo   fromOperation.checkValid();
2394bdc31837SStella Laurenzo   struct UserData {
2395bdc31837SStella Laurenzo     PyMlirContextRef context;
2396b56d1ec6SPeter Hawkins     nb::object callback;
2397bdc31837SStella Laurenzo     bool gotException;
2398bdc31837SStella Laurenzo     std::string exceptionWhat;
2399b56d1ec6SPeter Hawkins     nb::object exceptionType;
2400bdc31837SStella Laurenzo   };
2401bdc31837SStella Laurenzo   UserData userData{
2402bdc31837SStella Laurenzo       fromOperation.getContext(), std::move(callback), false, {}, {}};
2403bdc31837SStella Laurenzo   mlirSymbolTableWalkSymbolTables(
2404bdc31837SStella Laurenzo       fromOperation.get(), allSymUsesVisible,
2405bdc31837SStella Laurenzo       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2406bdc31837SStella Laurenzo         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2407bdc31837SStella Laurenzo         auto pyFoundOp =
2408bdc31837SStella Laurenzo             PyOperation::forOperation(calleeUserData->context, foundOp);
2409bdc31837SStella Laurenzo         if (calleeUserData->gotException)
2410bdc31837SStella Laurenzo           return;
2411bdc31837SStella Laurenzo         try {
2412bdc31837SStella Laurenzo           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2413b56d1ec6SPeter Hawkins         } catch (nb::python_error &e) {
2414bdc31837SStella Laurenzo           calleeUserData->gotException = true;
2415bdc31837SStella Laurenzo           calleeUserData->exceptionWhat = e.what();
2416b56d1ec6SPeter Hawkins           calleeUserData->exceptionType = nb::borrow(e.type());
2417bdc31837SStella Laurenzo         }
2418bdc31837SStella Laurenzo       },
2419bdc31837SStella Laurenzo       static_cast<void *>(&userData));
2420bdc31837SStella Laurenzo   if (userData.gotException) {
2421bdc31837SStella Laurenzo     std::string message("Exception raised in callback: ");
2422bdc31837SStella Laurenzo     message.append(userData.exceptionWhat);
2423337c937dSMehdi Amini     throw std::runtime_error(message);
2424bdc31837SStella Laurenzo   }
2425bdc31837SStella Laurenzo }
2426bdc31837SStella Laurenzo 
2427436c6c9cSStella Laurenzo namespace {
2428436c6c9cSStella Laurenzo 
2429436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument.
2430436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2431436c6c9cSStella Laurenzo public:
2432436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2433436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BlockArgument";
2434436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
2435436c6c9cSStella Laurenzo 
2436436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
2437b56d1ec6SPeter Hawkins     c.def_prop_ro("owner", [](PyBlockArgument &self) {
2438436c6c9cSStella Laurenzo       return PyBlock(self.getParentOperation(),
2439436c6c9cSStella Laurenzo                      mlirBlockArgumentGetOwner(self.get()));
2440436c6c9cSStella Laurenzo     });
2441b56d1ec6SPeter Hawkins     c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2442436c6c9cSStella Laurenzo       return mlirBlockArgumentGetArgNumber(self.get());
2443436c6c9cSStella Laurenzo     });
2444a6e7d024SStella Laurenzo     c.def(
2445a6e7d024SStella Laurenzo         "set_type",
2446a6e7d024SStella Laurenzo         [](PyBlockArgument &self, PyType type) {
2447436c6c9cSStella Laurenzo           return mlirBlockArgumentSetType(self.get(), type);
2448a6e7d024SStella Laurenzo         },
2449b56d1ec6SPeter Hawkins         nb::arg("type"));
2450436c6c9cSStella Laurenzo   }
2451436c6c9cSStella Laurenzo };
2452436c6c9cSStella Laurenzo 
2453436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive
2454436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the
2455436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in
2456436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime.
2457afeda4b9SAlex Zinenko class PyBlockArgumentList
2458afeda4b9SAlex Zinenko     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2459436c6c9cSStella Laurenzo public:
2460afeda4b9SAlex Zinenko   static constexpr const char *pyClassName = "BlockArgumentList";
24617c850867SMaksim Levental   using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2462436c6c9cSStella Laurenzo 
2463afeda4b9SAlex Zinenko   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2464afeda4b9SAlex Zinenko                       intptr_t startIndex = 0, intptr_t length = -1,
2465afeda4b9SAlex Zinenko                       intptr_t step = 1)
2466afeda4b9SAlex Zinenko       : Sliceable(startIndex,
2467afeda4b9SAlex Zinenko                   length == -1 ? mlirBlockGetNumArguments(block) : length,
2468afeda4b9SAlex Zinenko                   step),
2469afeda4b9SAlex Zinenko         operation(std::move(operation)), block(block) {}
2470afeda4b9SAlex Zinenko 
2471ee168fb9SAlex Zinenko   static void bindDerived(ClassTy &c) {
2472b56d1ec6SPeter Hawkins     c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2473ee168fb9SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
2474ee168fb9SAlex Zinenko     });
2475ee168fb9SAlex Zinenko   }
2476ee168fb9SAlex Zinenko 
2477ee168fb9SAlex Zinenko private:
2478ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
2479ee168fb9SAlex Zinenko   friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2480ee168fb9SAlex Zinenko 
2481afeda4b9SAlex Zinenko   /// Returns the number of arguments in the list.
2482ee168fb9SAlex Zinenko   intptr_t getRawNumElements() {
2483436c6c9cSStella Laurenzo     operation->checkValid();
2484436c6c9cSStella Laurenzo     return mlirBlockGetNumArguments(block);
2485436c6c9cSStella Laurenzo   }
2486436c6c9cSStella Laurenzo 
2487ee168fb9SAlex Zinenko   /// Returns `pos`-the element in the list.
2488ee168fb9SAlex Zinenko   PyBlockArgument getRawElement(intptr_t pos) {
2489afeda4b9SAlex Zinenko     MlirValue argument = mlirBlockGetArgument(block, pos);
2490afeda4b9SAlex Zinenko     return PyBlockArgument(operation, argument);
2491436c6c9cSStella Laurenzo   }
2492436c6c9cSStella Laurenzo 
2493afeda4b9SAlex Zinenko   /// Returns a sublist of this list.
2494afeda4b9SAlex Zinenko   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2495afeda4b9SAlex Zinenko                             intptr_t step) {
2496afeda4b9SAlex Zinenko     return PyBlockArgumentList(operation, block, startIndex, length, step);
2497436c6c9cSStella Laurenzo   }
2498436c6c9cSStella Laurenzo 
2499436c6c9cSStella Laurenzo   PyOperationRef operation;
2500436c6c9cSStella Laurenzo   MlirBlock block;
2501436c6c9cSStella Laurenzo };
2502436c6c9cSStella Laurenzo 
2503436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive
2504d7e49736SMaksim Levental /// elements, random access is cheap. The (returned) operand list is associated
2505d7e49736SMaksim Levental /// with the operation whose operands these are, and thus extends the lifetime
2506d7e49736SMaksim Levental /// of this operation.
2507436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2508436c6c9cSStella Laurenzo public:
2509436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpOperandList";
25107c850867SMaksim Levental   using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2511436c6c9cSStella Laurenzo 
2512436c6c9cSStella Laurenzo   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2513436c6c9cSStella Laurenzo                   intptr_t length = -1, intptr_t step = 1)
2514436c6c9cSStella Laurenzo       : Sliceable(startIndex,
2515436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2516436c6c9cSStella Laurenzo                                : length,
2517436c6c9cSStella Laurenzo                   step),
2518436c6c9cSStella Laurenzo         operation(operation) {}
2519436c6c9cSStella Laurenzo 
2520ee168fb9SAlex Zinenko   void dunderSetItem(intptr_t index, PyValue value) {
2521ee168fb9SAlex Zinenko     index = wrapIndex(index);
2522ee168fb9SAlex Zinenko     mlirOperationSetOperand(operation->get(), index, value.get());
2523ee168fb9SAlex Zinenko   }
2524ee168fb9SAlex Zinenko 
2525ee168fb9SAlex Zinenko   static void bindDerived(ClassTy &c) {
2526ee168fb9SAlex Zinenko     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2527ee168fb9SAlex Zinenko   }
2528ee168fb9SAlex Zinenko 
2529ee168fb9SAlex Zinenko private:
2530ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
2531ee168fb9SAlex Zinenko   friend class Sliceable<PyOpOperandList, PyValue>;
2532ee168fb9SAlex Zinenko 
2533ee168fb9SAlex Zinenko   intptr_t getRawNumElements() {
2534436c6c9cSStella Laurenzo     operation->checkValid();
2535436c6c9cSStella Laurenzo     return mlirOperationGetNumOperands(operation->get());
2536436c6c9cSStella Laurenzo   }
2537436c6c9cSStella Laurenzo 
2538ee168fb9SAlex Zinenko   PyValue getRawElement(intptr_t pos) {
25395664c5e2SJohn Demme     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
25405664c5e2SJohn Demme     MlirOperation owner;
25415664c5e2SJohn Demme     if (mlirValueIsAOpResult(operand))
25425664c5e2SJohn Demme       owner = mlirOpResultGetOwner(operand);
25435664c5e2SJohn Demme     else if (mlirValueIsABlockArgument(operand))
25445664c5e2SJohn Demme       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
25455664c5e2SJohn Demme     else
25465664c5e2SJohn Demme       assert(false && "Value must be an block arg or op result.");
25475664c5e2SJohn Demme     PyOperationRef pyOwner =
25485664c5e2SJohn Demme         PyOperation::forOperation(operation->getContext(), owner);
25495664c5e2SJohn Demme     return PyValue(pyOwner, operand);
2550436c6c9cSStella Laurenzo   }
2551436c6c9cSStella Laurenzo 
2552436c6c9cSStella Laurenzo   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2553436c6c9cSStella Laurenzo     return PyOpOperandList(operation, startIndex, length, step);
2554436c6c9cSStella Laurenzo   }
2555436c6c9cSStella Laurenzo 
2556436c6c9cSStella Laurenzo   PyOperationRef operation;
2557436c6c9cSStella Laurenzo };
2558436c6c9cSStella Laurenzo 
2559d7e49736SMaksim Levental /// A list of operation successors. Internally, these are stored as consecutive
2560d7e49736SMaksim Levental /// elements, random access is cheap. The (returned) successor list is
2561d7e49736SMaksim Levental /// associated with the operation whose successors these are, and thus extends
2562d7e49736SMaksim Levental /// the lifetime of this operation.
2563d7e49736SMaksim Levental class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2564d7e49736SMaksim Levental public:
2565d7e49736SMaksim Levental   static constexpr const char *pyClassName = "OpSuccessors";
2566d7e49736SMaksim Levental 
2567d7e49736SMaksim Levental   PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2568d7e49736SMaksim Levental                  intptr_t length = -1, intptr_t step = 1)
2569d7e49736SMaksim Levental       : Sliceable(startIndex,
2570d7e49736SMaksim Levental                   length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2571d7e49736SMaksim Levental                                : length,
2572d7e49736SMaksim Levental                   step),
2573d7e49736SMaksim Levental         operation(operation) {}
2574d7e49736SMaksim Levental 
2575d7e49736SMaksim Levental   void dunderSetItem(intptr_t index, PyBlock block) {
2576d7e49736SMaksim Levental     index = wrapIndex(index);
2577d7e49736SMaksim Levental     mlirOperationSetSuccessor(operation->get(), index, block.get());
2578d7e49736SMaksim Levental   }
2579d7e49736SMaksim Levental 
2580d7e49736SMaksim Levental   static void bindDerived(ClassTy &c) {
2581d7e49736SMaksim Levental     c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2582d7e49736SMaksim Levental   }
2583d7e49736SMaksim Levental 
2584d7e49736SMaksim Levental private:
2585d7e49736SMaksim Levental   /// Give the parent CRTP class access to hook implementations below.
2586d7e49736SMaksim Levental   friend class Sliceable<PyOpSuccessors, PyBlock>;
2587d7e49736SMaksim Levental 
2588d7e49736SMaksim Levental   intptr_t getRawNumElements() {
2589d7e49736SMaksim Levental     operation->checkValid();
2590d7e49736SMaksim Levental     return mlirOperationGetNumSuccessors(operation->get());
2591d7e49736SMaksim Levental   }
2592d7e49736SMaksim Levental 
2593d7e49736SMaksim Levental   PyBlock getRawElement(intptr_t pos) {
2594d7e49736SMaksim Levental     MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2595d7e49736SMaksim Levental     return PyBlock(operation, block);
2596d7e49736SMaksim Levental   }
2597d7e49736SMaksim Levental 
2598d7e49736SMaksim Levental   PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2599d7e49736SMaksim Levental     return PyOpSuccessors(operation, startIndex, length, step);
2600d7e49736SMaksim Levental   }
2601d7e49736SMaksim Levental 
2602d7e49736SMaksim Levental   PyOperationRef operation;
2603d7e49736SMaksim Levental };
2604d7e49736SMaksim Levental 
2605436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing
2606436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes.
2607436c6c9cSStella Laurenzo class PyOpAttributeMap {
2608436c6c9cSStella Laurenzo public:
26091fc096afSMehdi Amini   PyOpAttributeMap(PyOperationRef operation)
26101fc096afSMehdi Amini       : operation(std::move(operation)) {}
2611436c6c9cSStella Laurenzo 
2612974c1596SRahul Kayaith   MlirAttribute dunderGetItemNamed(const std::string &name) {
2613436c6c9cSStella Laurenzo     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2614436c6c9cSStella Laurenzo                                                          toMlirStringRef(name));
2615436c6c9cSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
2616b56d1ec6SPeter Hawkins       throw nb::key_error("attempt to access a non-existent attribute");
2617436c6c9cSStella Laurenzo     }
2618974c1596SRahul Kayaith     return attr;
2619436c6c9cSStella Laurenzo   }
2620436c6c9cSStella Laurenzo 
2621436c6c9cSStella Laurenzo   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2622436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
2623b56d1ec6SPeter Hawkins       throw nb::index_error("attempt to access out of bounds attribute");
2624436c6c9cSStella Laurenzo     }
2625436c6c9cSStella Laurenzo     MlirNamedAttribute namedAttr =
2626436c6c9cSStella Laurenzo         mlirOperationGetAttribute(operation->get(), index);
2627436c6c9cSStella Laurenzo     return PyNamedAttribute(
2628436c6c9cSStella Laurenzo         namedAttr.attribute,
2629120591e1SRiver Riddle         std::string(mlirIdentifierStr(namedAttr.name).data,
2630120591e1SRiver Riddle                     mlirIdentifierStr(namedAttr.name).length));
2631436c6c9cSStella Laurenzo   }
2632436c6c9cSStella Laurenzo 
26331fc096afSMehdi Amini   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2634436c6c9cSStella Laurenzo     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2635436c6c9cSStella Laurenzo                                     attr);
2636436c6c9cSStella Laurenzo   }
2637436c6c9cSStella Laurenzo 
2638436c6c9cSStella Laurenzo   void dunderDelItem(const std::string &name) {
2639436c6c9cSStella Laurenzo     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2640436c6c9cSStella Laurenzo                                                      toMlirStringRef(name));
2641436c6c9cSStella Laurenzo     if (!removed)
2642b56d1ec6SPeter Hawkins       throw nb::key_error("attempt to delete a non-existent attribute");
2643436c6c9cSStella Laurenzo   }
2644436c6c9cSStella Laurenzo 
2645436c6c9cSStella Laurenzo   intptr_t dunderLen() {
2646436c6c9cSStella Laurenzo     return mlirOperationGetNumAttributes(operation->get());
2647436c6c9cSStella Laurenzo   }
2648436c6c9cSStella Laurenzo 
2649436c6c9cSStella Laurenzo   bool dunderContains(const std::string &name) {
2650436c6c9cSStella Laurenzo     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2651436c6c9cSStella Laurenzo         operation->get(), toMlirStringRef(name)));
2652436c6c9cSStella Laurenzo   }
2653436c6c9cSStella Laurenzo 
2654b56d1ec6SPeter Hawkins   static void bind(nb::module_ &m) {
2655b56d1ec6SPeter Hawkins     nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2656436c6c9cSStella Laurenzo         .def("__contains__", &PyOpAttributeMap::dunderContains)
2657436c6c9cSStella Laurenzo         .def("__len__", &PyOpAttributeMap::dunderLen)
2658436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2659436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2660436c6c9cSStella Laurenzo         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2661436c6c9cSStella Laurenzo         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2662436c6c9cSStella Laurenzo   }
2663436c6c9cSStella Laurenzo 
2664436c6c9cSStella Laurenzo private:
2665436c6c9cSStella Laurenzo   PyOperationRef operation;
2666436c6c9cSStella Laurenzo };
2667436c6c9cSStella Laurenzo 
2668be0a7e9fSMehdi Amini } // namespace
2669436c6c9cSStella Laurenzo 
2670436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2671436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule.
2672436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2673436c6c9cSStella Laurenzo 
2674b56d1ec6SPeter Hawkins void mlir::python::populateIRCore(nb::module_ &m) {
267508e2c15aSMaksim Levental   // disable leak warnings which tend to be false positives.
267608e2c15aSMaksim Levental   nb::set_leak_warnings(false);
2677436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
26787ee25bc5SStella Laurenzo   // Enums.
26797ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
2680b56d1ec6SPeter Hawkins   nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
26817ee25bc5SStella Laurenzo       .value("ERROR", MlirDiagnosticError)
26827ee25bc5SStella Laurenzo       .value("WARNING", MlirDiagnosticWarning)
26837ee25bc5SStella Laurenzo       .value("NOTE", MlirDiagnosticNote)
26847ee25bc5SStella Laurenzo       .value("REMARK", MlirDiagnosticRemark);
26857ee25bc5SStella Laurenzo 
2686b56d1ec6SPeter Hawkins   nb::enum_<MlirWalkOrder>(m, "WalkOrder")
268747148832SHideto Ueno       .value("PRE_ORDER", MlirWalkPreOrder)
268847148832SHideto Ueno       .value("POST_ORDER", MlirWalkPostOrder);
268947148832SHideto Ueno 
2690b56d1ec6SPeter Hawkins   nb::enum_<MlirWalkResult>(m, "WalkResult")
269147148832SHideto Ueno       .value("ADVANCE", MlirWalkResultAdvance)
269247148832SHideto Ueno       .value("INTERRUPT", MlirWalkResultInterrupt)
269347148832SHideto Ueno       .value("SKIP", MlirWalkResultSkip);
269447148832SHideto Ueno 
26957ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
26967ee25bc5SStella Laurenzo   // Mapping of Diagnostics.
26977ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
2698b56d1ec6SPeter Hawkins   nb::class_<PyDiagnostic>(m, "Diagnostic")
2699b56d1ec6SPeter Hawkins       .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2700b56d1ec6SPeter Hawkins       .def_prop_ro("location", &PyDiagnostic::getLocation)
2701b56d1ec6SPeter Hawkins       .def_prop_ro("message", &PyDiagnostic::getMessage)
2702b56d1ec6SPeter Hawkins       .def_prop_ro("notes", &PyDiagnostic::getNotes)
2703b56d1ec6SPeter Hawkins       .def("__str__", [](PyDiagnostic &self) -> nb::str {
27047ee25bc5SStella Laurenzo         if (!self.isValid())
2705b56d1ec6SPeter Hawkins           return nb::str("<Invalid Diagnostic>");
27067ee25bc5SStella Laurenzo         return self.getMessage();
27077ee25bc5SStella Laurenzo       });
27087ee25bc5SStella Laurenzo 
2709b56d1ec6SPeter Hawkins   nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2710b56d1ec6SPeter Hawkins       .def("__init__",
2711b56d1ec6SPeter Hawkins            [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
2712b56d1ec6SPeter Hawkins              new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2713b56d1ec6SPeter Hawkins            })
2714b56d1ec6SPeter Hawkins       .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2715b56d1ec6SPeter Hawkins       .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2716b56d1ec6SPeter Hawkins       .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2717b56d1ec6SPeter Hawkins       .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
27183ea4c501SRahul Kayaith       .def("__str__",
27193ea4c501SRahul Kayaith            [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
27203ea4c501SRahul Kayaith 
2721b56d1ec6SPeter Hawkins   nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
27227ee25bc5SStella Laurenzo       .def("detach", &PyDiagnosticHandler::detach)
2723b56d1ec6SPeter Hawkins       .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2724b56d1ec6SPeter Hawkins       .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
27257ee25bc5SStella Laurenzo       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2726b56d1ec6SPeter Hawkins       .def("__exit__", &PyDiagnosticHandler::contextExit,
2727b56d1ec6SPeter Hawkins            nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2728b56d1ec6SPeter Hawkins            nb::arg("traceback").none());
27297ee25bc5SStella Laurenzo 
27307ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
27314acd8457SAlex Zinenko   // Mapping of MlirContext.
27325e83a5b4SStella Laurenzo   // Note that this is exported as _BaseContext. The containing, Python level
27335e83a5b4SStella Laurenzo   // __init__.py will subclass it with site-specific functionality and set a
27345e83a5b4SStella Laurenzo   // "Context" attribute on this module.
2735436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2736b56d1ec6SPeter Hawkins   nb::class_<PyMlirContext>(m, "_BaseContext")
2737b56d1ec6SPeter Hawkins       .def("__init__",
2738b56d1ec6SPeter Hawkins            [](PyMlirContext &self) {
2739b56d1ec6SPeter Hawkins              MlirContext context = mlirContextCreateWithThreading(false);
2740b56d1ec6SPeter Hawkins              new (&self) PyMlirContext(context);
2741b56d1ec6SPeter Hawkins            })
2742436c6c9cSStella Laurenzo       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2743436c6c9cSStella Laurenzo       .def("_get_context_again",
2744436c6c9cSStella Laurenzo            [](PyMlirContext &self) {
2745436c6c9cSStella Laurenzo              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2746436c6c9cSStella Laurenzo              return ref.releaseObject();
2747436c6c9cSStella Laurenzo            })
2748436c6c9cSStella Laurenzo       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2749d1fdb416SJohn Demme       .def("_get_live_operation_objects",
2750d1fdb416SJohn Demme            &PyMlirContext::getLiveOperationObjects)
27516b0bed7eSJohn Demme       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
275291f11611SOleksandr "Alex" Zinenko       .def("_clear_live_operations_inside",
2753b56d1ec6SPeter Hawkins            nb::overload_cast<MlirOperation>(
275491f11611SOleksandr "Alex" Zinenko                &PyMlirContext::clearOperationsInside))
2755436c6c9cSStella Laurenzo       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2756b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2757436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2758436c6c9cSStella Laurenzo       .def("__enter__", &PyMlirContext::contextEnter)
2759b56d1ec6SPeter Hawkins       .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2760b56d1ec6SPeter Hawkins            nb::arg("exc_value").none(), nb::arg("traceback").none())
2761b56d1ec6SPeter Hawkins       .def_prop_ro_static(
2762436c6c9cSStella Laurenzo           "current",
2763b56d1ec6SPeter Hawkins           [](nb::object & /*class*/) {
2764436c6c9cSStella Laurenzo             auto *context = PyThreadContextEntry::getDefaultContext();
2765436c6c9cSStella Laurenzo             if (!context)
2766b56d1ec6SPeter Hawkins               return nb::none();
2767b56d1ec6SPeter Hawkins             return nb::cast(context);
2768436c6c9cSStella Laurenzo           },
2769436c6c9cSStella Laurenzo           "Gets the Context bound to the current thread or raises ValueError")
2770b56d1ec6SPeter Hawkins       .def_prop_ro(
2771436c6c9cSStella Laurenzo           "dialects",
2772436c6c9cSStella Laurenzo           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2773436c6c9cSStella Laurenzo           "Gets a container for accessing dialects by name")
2774b56d1ec6SPeter Hawkins       .def_prop_ro(
2775436c6c9cSStella Laurenzo           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2776436c6c9cSStella Laurenzo           "Alias for 'dialect'")
2777436c6c9cSStella Laurenzo       .def(
2778436c6c9cSStella Laurenzo           "get_dialect_descriptor",
2779436c6c9cSStella Laurenzo           [=](PyMlirContext &self, std::string &name) {
2780436c6c9cSStella Laurenzo             MlirDialect dialect = mlirContextGetOrLoadDialect(
2781436c6c9cSStella Laurenzo                 self.get(), {name.data(), name.size()});
2782436c6c9cSStella Laurenzo             if (mlirDialectIsNull(dialect)) {
2783b56d1ec6SPeter Hawkins               throw nb::value_error(
2784b56d1ec6SPeter Hawkins                   (Twine("Dialect '") + name + "' not found").str().c_str());
2785436c6c9cSStella Laurenzo             }
2786436c6c9cSStella Laurenzo             return PyDialectDescriptor(self.getRef(), dialect);
2787436c6c9cSStella Laurenzo           },
2788b56d1ec6SPeter Hawkins           nb::arg("dialect_name"),
2789436c6c9cSStella Laurenzo           "Gets or loads a dialect by name, returning its descriptor object")
2790b56d1ec6SPeter Hawkins       .def_prop_rw(
2791436c6c9cSStella Laurenzo           "allow_unregistered_dialects",
2792436c6c9cSStella Laurenzo           [](PyMlirContext &self) -> bool {
2793436c6c9cSStella Laurenzo             return mlirContextGetAllowUnregisteredDialects(self.get());
2794436c6c9cSStella Laurenzo           },
2795436c6c9cSStella Laurenzo           [](PyMlirContext &self, bool value) {
2796436c6c9cSStella Laurenzo             mlirContextSetAllowUnregisteredDialects(self.get(), value);
27979a9214faSStella Laurenzo           })
27987ee25bc5SStella Laurenzo       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2799b56d1ec6SPeter Hawkins            nb::arg("callback"),
28007ee25bc5SStella Laurenzo            "Attaches a diagnostic handler that will receive callbacks")
2801a6e7d024SStella Laurenzo       .def(
2802a6e7d024SStella Laurenzo           "enable_multithreading",
2803caa159f0SNicolas Vasilache           [](PyMlirContext &self, bool enable) {
2804caa159f0SNicolas Vasilache             mlirContextEnableMultithreading(self.get(), enable);
2805a6e7d024SStella Laurenzo           },
2806b56d1ec6SPeter Hawkins           nb::arg("enable"))
2807a6e7d024SStella Laurenzo       .def(
2808a6e7d024SStella Laurenzo           "is_registered_operation",
28099a9214faSStella Laurenzo           [](PyMlirContext &self, std::string &name) {
28109a9214faSStella Laurenzo             return mlirContextIsRegisteredOperation(
28119a9214faSStella Laurenzo                 self.get(), MlirStringRef{name.data(), name.size()});
2812a6e7d024SStella Laurenzo           },
2813b56d1ec6SPeter Hawkins           nb::arg("operation_name"))
28145e83a5b4SStella Laurenzo       .def(
28155e83a5b4SStella Laurenzo           "append_dialect_registry",
28165e83a5b4SStella Laurenzo           [](PyMlirContext &self, PyDialectRegistry &registry) {
28175e83a5b4SStella Laurenzo             mlirContextAppendDialectRegistry(self.get(), registry);
28185e83a5b4SStella Laurenzo           },
2819b56d1ec6SPeter Hawkins           nb::arg("registry"))
2820b56d1ec6SPeter Hawkins       .def_prop_rw("emit_error_diagnostics", nullptr,
28213ea4c501SRahul Kayaith                    &PyMlirContext::setEmitErrorDiagnostics,
28223ea4c501SRahul Kayaith                    "Emit error diagnostics to diagnostic handlers. By default "
28233ea4c501SRahul Kayaith                    "error diagnostics are captured and reported through "
28243ea4c501SRahul Kayaith                    "MLIRError exceptions.")
28255e83a5b4SStella Laurenzo       .def("load_all_available_dialects", [](PyMlirContext &self) {
28265e83a5b4SStella Laurenzo         mlirContextLoadAllAvailableDialects(self.get());
28275e83a5b4SStella Laurenzo       });
2828436c6c9cSStella Laurenzo 
2829436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2830436c6c9cSStella Laurenzo   // Mapping of PyDialectDescriptor
2831436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2832b56d1ec6SPeter Hawkins   nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2833b56d1ec6SPeter Hawkins       .def_prop_ro("namespace",
2834436c6c9cSStella Laurenzo                    [](PyDialectDescriptor &self) {
2835b56d1ec6SPeter Hawkins                      MlirStringRef ns = mlirDialectGetNamespace(self.get());
2836b56d1ec6SPeter Hawkins                      return nb::str(ns.data, ns.length);
2837436c6c9cSStella Laurenzo                    })
2838436c6c9cSStella Laurenzo       .def("__repr__", [](PyDialectDescriptor &self) {
2839436c6c9cSStella Laurenzo         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2840436c6c9cSStella Laurenzo         std::string repr("<DialectDescriptor ");
2841436c6c9cSStella Laurenzo         repr.append(ns.data, ns.length);
2842436c6c9cSStella Laurenzo         repr.append(">");
2843436c6c9cSStella Laurenzo         return repr;
2844436c6c9cSStella Laurenzo       });
2845436c6c9cSStella Laurenzo 
2846436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2847436c6c9cSStella Laurenzo   // Mapping of PyDialects
2848436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2849b56d1ec6SPeter Hawkins   nb::class_<PyDialects>(m, "Dialects")
2850436c6c9cSStella Laurenzo       .def("__getitem__",
2851436c6c9cSStella Laurenzo            [=](PyDialects &self, std::string keyName) {
2852436c6c9cSStella Laurenzo              MlirDialect dialect =
2853436c6c9cSStella Laurenzo                  self.getDialectForKey(keyName, /*attrError=*/false);
2854b56d1ec6SPeter Hawkins              nb::object descriptor =
2855b56d1ec6SPeter Hawkins                  nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2856436c6c9cSStella Laurenzo              return createCustomDialectWrapper(keyName, std::move(descriptor));
2857436c6c9cSStella Laurenzo            })
2858436c6c9cSStella Laurenzo       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2859436c6c9cSStella Laurenzo         MlirDialect dialect =
2860436c6c9cSStella Laurenzo             self.getDialectForKey(attrName, /*attrError=*/true);
2861b56d1ec6SPeter Hawkins         nb::object descriptor =
2862b56d1ec6SPeter Hawkins             nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2863436c6c9cSStella Laurenzo         return createCustomDialectWrapper(attrName, std::move(descriptor));
2864436c6c9cSStella Laurenzo       });
2865436c6c9cSStella Laurenzo 
2866436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2867436c6c9cSStella Laurenzo   // Mapping of PyDialect
2868436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2869b56d1ec6SPeter Hawkins   nb::class_<PyDialect>(m, "Dialect")
2870b56d1ec6SPeter Hawkins       .def(nb::init<nb::object>(), nb::arg("descriptor"))
2871b56d1ec6SPeter Hawkins       .def_prop_ro("descriptor",
2872b56d1ec6SPeter Hawkins                    [](PyDialect &self) { return self.getDescriptor(); })
2873b56d1ec6SPeter Hawkins       .def("__repr__", [](nb::object self) {
2874436c6c9cSStella Laurenzo         auto clazz = self.attr("__class__");
2875b56d1ec6SPeter Hawkins         return nb::str("<Dialect ") +
2876b56d1ec6SPeter Hawkins                self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
2877b56d1ec6SPeter Hawkins                clazz.attr("__module__") + nb::str(".") +
2878b56d1ec6SPeter Hawkins                clazz.attr("__name__") + nb::str(")>");
2879436c6c9cSStella Laurenzo       });
2880436c6c9cSStella Laurenzo 
2881436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
28825e83a5b4SStella Laurenzo   // Mapping of PyDialectRegistry
28835e83a5b4SStella Laurenzo   //----------------------------------------------------------------------------
2884b56d1ec6SPeter Hawkins   nb::class_<PyDialectRegistry>(m, "DialectRegistry")
2885b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
28865e83a5b4SStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2887b56d1ec6SPeter Hawkins       .def(nb::init<>());
28885e83a5b4SStella Laurenzo 
28895e83a5b4SStella Laurenzo   //----------------------------------------------------------------------------
2890436c6c9cSStella Laurenzo   // Mapping of Location
2891436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2892b56d1ec6SPeter Hawkins   nb::class_<PyLocation>(m, "Location")
2893b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2894436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2895436c6c9cSStella Laurenzo       .def("__enter__", &PyLocation::contextEnter)
2896b56d1ec6SPeter Hawkins       .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
2897b56d1ec6SPeter Hawkins            nb::arg("exc_value").none(), nb::arg("traceback").none())
2898436c6c9cSStella Laurenzo       .def("__eq__",
2899436c6c9cSStella Laurenzo            [](PyLocation &self, PyLocation &other) -> bool {
2900436c6c9cSStella Laurenzo              return mlirLocationEqual(self, other);
2901436c6c9cSStella Laurenzo            })
2902b56d1ec6SPeter Hawkins       .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
2903b56d1ec6SPeter Hawkins       .def_prop_ro_static(
2904436c6c9cSStella Laurenzo           "current",
2905b56d1ec6SPeter Hawkins           [](nb::object & /*class*/) {
2906436c6c9cSStella Laurenzo             auto *loc = PyThreadContextEntry::getDefaultLocation();
2907436c6c9cSStella Laurenzo             if (!loc)
2908b56d1ec6SPeter Hawkins               throw nb::value_error("No current Location");
2909436c6c9cSStella Laurenzo             return loc;
2910436c6c9cSStella Laurenzo           },
2911436c6c9cSStella Laurenzo           "Gets the Location bound to the current thread or raises ValueError")
2912436c6c9cSStella Laurenzo       .def_static(
2913436c6c9cSStella Laurenzo           "unknown",
2914436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
2915436c6c9cSStella Laurenzo             return PyLocation(context->getRef(),
2916436c6c9cSStella Laurenzo                               mlirLocationUnknownGet(context->get()));
2917436c6c9cSStella Laurenzo           },
2918b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
2919436c6c9cSStella Laurenzo           "Gets a Location representing an unknown location")
2920436c6c9cSStella Laurenzo       .def_static(
2921e67cbbefSJacques Pienaar           "callsite",
2922e67cbbefSJacques Pienaar           [](PyLocation callee, const std::vector<PyLocation> &frames,
2923e67cbbefSJacques Pienaar              DefaultingPyMlirContext context) {
2924e67cbbefSJacques Pienaar             if (frames.empty())
2925b56d1ec6SPeter Hawkins               throw nb::value_error("No caller frames provided");
2926e67cbbefSJacques Pienaar             MlirLocation caller = frames.back().get();
2927e2f16be5SMehdi Amini             for (const PyLocation &frame :
2928984b800aSserge-sans-paille                  llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2929e67cbbefSJacques Pienaar               caller = mlirLocationCallSiteGet(frame.get(), caller);
2930e67cbbefSJacques Pienaar             return PyLocation(context->getRef(),
2931e67cbbefSJacques Pienaar                               mlirLocationCallSiteGet(callee.get(), caller));
2932e67cbbefSJacques Pienaar           },
2933b56d1ec6SPeter Hawkins           nb::arg("callee"), nb::arg("frames"),
2934b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
2935e67cbbefSJacques Pienaar           kContextGetCallSiteLocationDocstring)
2936e67cbbefSJacques Pienaar       .def_static(
2937436c6c9cSStella Laurenzo           "file",
2938436c6c9cSStella Laurenzo           [](std::string filename, int line, int col,
2939436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
2940436c6c9cSStella Laurenzo             return PyLocation(
2941436c6c9cSStella Laurenzo                 context->getRef(),
2942436c6c9cSStella Laurenzo                 mlirLocationFileLineColGet(
2943436c6c9cSStella Laurenzo                     context->get(), toMlirStringRef(filename), line, col));
2944436c6c9cSStella Laurenzo           },
2945b56d1ec6SPeter Hawkins           nb::arg("filename"), nb::arg("line"), nb::arg("col"),
2946b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
2947b56d1ec6SPeter Hawkins           kContextGetFileLocationDocstring)
294804d76d36SJacques Pienaar       .def_static(
2949a77250fdSJacques Pienaar           "file",
2950a77250fdSJacques Pienaar           [](std::string filename, int startLine, int startCol, int endLine,
2951a77250fdSJacques Pienaar              int endCol, DefaultingPyMlirContext context) {
2952a77250fdSJacques Pienaar             return PyLocation(context->getRef(),
2953a77250fdSJacques Pienaar                               mlirLocationFileLineColRangeGet(
2954a77250fdSJacques Pienaar                                   context->get(), toMlirStringRef(filename),
2955a77250fdSJacques Pienaar                                   startLine, startCol, endLine, endCol));
2956a77250fdSJacques Pienaar           },
2957a77250fdSJacques Pienaar           nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
2958a77250fdSJacques Pienaar           nb::arg("end_line"), nb::arg("end_col"),
2959a77250fdSJacques Pienaar           nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
2960a77250fdSJacques Pienaar       .def_static(
29611ab3efacSJacques Pienaar           "fused",
29627ee25bc5SStella Laurenzo           [](const std::vector<PyLocation> &pyLocations,
29630a81ace0SKazu Hirata              std::optional<PyAttribute> metadata,
29641ab3efacSJacques Pienaar              DefaultingPyMlirContext context) {
29651ab3efacSJacques Pienaar             llvm::SmallVector<MlirLocation, 4> locations;
29661ab3efacSJacques Pienaar             locations.reserve(pyLocations.size());
29671ab3efacSJacques Pienaar             for (auto &pyLocation : pyLocations)
29681ab3efacSJacques Pienaar               locations.push_back(pyLocation.get());
29691ab3efacSJacques Pienaar             MlirLocation location = mlirLocationFusedGet(
29701ab3efacSJacques Pienaar                 context->get(), locations.size(), locations.data(),
29711ab3efacSJacques Pienaar                 metadata ? metadata->get() : MlirAttribute{0});
29721ab3efacSJacques Pienaar             return PyLocation(context->getRef(), location);
29731ab3efacSJacques Pienaar           },
2974b56d1ec6SPeter Hawkins           nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
2975b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
2976b56d1ec6SPeter Hawkins           kContextGetFusedLocationDocstring)
29771ab3efacSJacques Pienaar       .def_static(
297804d76d36SJacques Pienaar           "name",
29790a81ace0SKazu Hirata           [](std::string name, std::optional<PyLocation> childLoc,
298004d76d36SJacques Pienaar              DefaultingPyMlirContext context) {
298104d76d36SJacques Pienaar             return PyLocation(
298204d76d36SJacques Pienaar                 context->getRef(),
298304d76d36SJacques Pienaar                 mlirLocationNameGet(
298404d76d36SJacques Pienaar                     context->get(), toMlirStringRef(name),
298504d76d36SJacques Pienaar                     childLoc ? childLoc->get()
298604d76d36SJacques Pienaar                              : mlirLocationUnknownGet(context->get())));
298704d76d36SJacques Pienaar           },
2988b56d1ec6SPeter Hawkins           nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
2989b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
2990b56d1ec6SPeter Hawkins           kContextGetNameLocationDocString)
2991792f3c81SAndrew Young       .def_static(
2992792f3c81SAndrew Young           "from_attr",
2993792f3c81SAndrew Young           [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2994792f3c81SAndrew Young             return PyLocation(context->getRef(),
2995792f3c81SAndrew Young                               mlirLocationFromAttribute(attribute));
2996792f3c81SAndrew Young           },
2997b56d1ec6SPeter Hawkins           nb::arg("attribute"), nb::arg("context").none() = nb::none(),
2998792f3c81SAndrew Young           "Gets a Location from a LocationAttr")
2999b56d1ec6SPeter Hawkins       .def_prop_ro(
3000436c6c9cSStella Laurenzo           "context",
3001436c6c9cSStella Laurenzo           [](PyLocation &self) { return self.getContext().getObject(); },
3002436c6c9cSStella Laurenzo           "Context that owns the Location")
3003b56d1ec6SPeter Hawkins       .def_prop_ro(
3004792f3c81SAndrew Young           "attr",
30059566ee28Smax           [](PyLocation &self) { return mlirLocationGetAttribute(self); },
3006792f3c81SAndrew Young           "Get the underlying LocationAttr")
30077ee25bc5SStella Laurenzo       .def(
30087ee25bc5SStella Laurenzo           "emit_error",
30097ee25bc5SStella Laurenzo           [](PyLocation &self, std::string message) {
30107ee25bc5SStella Laurenzo             mlirEmitError(self, message.c_str());
30117ee25bc5SStella Laurenzo           },
3012b56d1ec6SPeter Hawkins           nb::arg("message"), "Emits an error at this location")
3013436c6c9cSStella Laurenzo       .def("__repr__", [](PyLocation &self) {
3014436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
3015436c6c9cSStella Laurenzo         mlirLocationPrint(self, printAccum.getCallback(),
3016436c6c9cSStella Laurenzo                           printAccum.getUserData());
3017436c6c9cSStella Laurenzo         return printAccum.join();
3018436c6c9cSStella Laurenzo       });
3019436c6c9cSStella Laurenzo 
3020436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3021436c6c9cSStella Laurenzo   // Mapping of Module
3022436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3023b56d1ec6SPeter Hawkins   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3024b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3025436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3026436c6c9cSStella Laurenzo       .def_static(
3027436c6c9cSStella Laurenzo           "parse",
30283ea4c501SRahul Kayaith           [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
30293ea4c501SRahul Kayaith             PyMlirContext::ErrorCapture errors(context->getRef());
3030436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateParse(
3031436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(moduleAsm));
30323ea4c501SRahul Kayaith             if (mlirModuleIsNull(module))
30333ea4c501SRahul Kayaith               throw MLIRError("Unable to parse module assembly", errors.take());
3034436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
3035436c6c9cSStella Laurenzo           },
3036b56d1ec6SPeter Hawkins           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3037b56d1ec6SPeter Hawkins           kModuleParseDocstring)
3038b56d1ec6SPeter Hawkins       .def_static(
3039b56d1ec6SPeter Hawkins           "parse",
3040b56d1ec6SPeter Hawkins           [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
3041b56d1ec6SPeter Hawkins             PyMlirContext::ErrorCapture errors(context->getRef());
3042b56d1ec6SPeter Hawkins             MlirModule module = mlirModuleCreateParse(
3043b56d1ec6SPeter Hawkins                 context->get(), toMlirStringRef(moduleAsm));
3044b56d1ec6SPeter Hawkins             if (mlirModuleIsNull(module))
3045b56d1ec6SPeter Hawkins               throw MLIRError("Unable to parse module assembly", errors.take());
3046b56d1ec6SPeter Hawkins             return PyModule::forModule(module).releaseObject();
3047b56d1ec6SPeter Hawkins           },
3048b56d1ec6SPeter Hawkins           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3049436c6c9cSStella Laurenzo           kModuleParseDocstring)
3050436c6c9cSStella Laurenzo       .def_static(
3051436c6c9cSStella Laurenzo           "create",
3052436c6c9cSStella Laurenzo           [](DefaultingPyLocation loc) {
3053436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateEmpty(loc);
3054436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
3055436c6c9cSStella Laurenzo           },
3056b56d1ec6SPeter Hawkins           nb::arg("loc").none() = nb::none(), "Creates an empty module")
3057b56d1ec6SPeter Hawkins       .def_prop_ro(
3058436c6c9cSStella Laurenzo           "context",
3059436c6c9cSStella Laurenzo           [](PyModule &self) { return self.getContext().getObject(); },
3060436c6c9cSStella Laurenzo           "Context that created the Module")
3061b56d1ec6SPeter Hawkins       .def_prop_ro(
3062436c6c9cSStella Laurenzo           "operation",
3063436c6c9cSStella Laurenzo           [](PyModule &self) {
3064436c6c9cSStella Laurenzo             return PyOperation::forOperation(self.getContext(),
3065436c6c9cSStella Laurenzo                                              mlirModuleGetOperation(self.get()),
3066436c6c9cSStella Laurenzo                                              self.getRef().releaseObject())
3067436c6c9cSStella Laurenzo                 .releaseObject();
3068436c6c9cSStella Laurenzo           },
3069436c6c9cSStella Laurenzo           "Accesses the module as an operation")
3070b56d1ec6SPeter Hawkins       .def_prop_ro(
3071436c6c9cSStella Laurenzo           "body",
3072436c6c9cSStella Laurenzo           [](PyModule &self) {
307302b6fb21SMehdi Amini             PyOperationRef moduleOp = PyOperation::forOperation(
3074436c6c9cSStella Laurenzo                 self.getContext(), mlirModuleGetOperation(self.get()),
3075436c6c9cSStella Laurenzo                 self.getRef().releaseObject());
307602b6fb21SMehdi Amini             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3077436c6c9cSStella Laurenzo             return returnBlock;
3078436c6c9cSStella Laurenzo           },
3079436c6c9cSStella Laurenzo           "Return the block for this module")
3080436c6c9cSStella Laurenzo       .def(
3081436c6c9cSStella Laurenzo           "dump",
3082436c6c9cSStella Laurenzo           [](PyModule &self) {
3083436c6c9cSStella Laurenzo             mlirOperationDump(mlirModuleGetOperation(self.get()));
3084436c6c9cSStella Laurenzo           },
3085436c6c9cSStella Laurenzo           kDumpDocstring)
3086436c6c9cSStella Laurenzo       .def(
3087436c6c9cSStella Laurenzo           "__str__",
3088b56d1ec6SPeter Hawkins           [](nb::object self) {
3089ace1d0adSStella Laurenzo             // Defer to the operation's __str__.
3090ace1d0adSStella Laurenzo             return self.attr("operation").attr("__str__")();
3091436c6c9cSStella Laurenzo           },
3092436c6c9cSStella Laurenzo           kOperationStrDunderDocstring);
3093436c6c9cSStella Laurenzo 
3094436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3095436c6c9cSStella Laurenzo   // Mapping of Operation.
3096436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3097b56d1ec6SPeter Hawkins   nb::class_<PyOperationBase>(m, "_OperationBase")
3098b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
30991fb2e842SStella Laurenzo                    [](PyOperationBase &self) {
31001fb2e842SStella Laurenzo                      return self.getOperation().getCapsule();
31011fb2e842SStella Laurenzo                    })
3102436c6c9cSStella Laurenzo       .def("__eq__",
3103436c6c9cSStella Laurenzo            [](PyOperationBase &self, PyOperationBase &other) {
3104436c6c9cSStella Laurenzo              return &self.getOperation() == &other.getOperation();
3105436c6c9cSStella Laurenzo            })
3106436c6c9cSStella Laurenzo       .def("__eq__",
3107b56d1ec6SPeter Hawkins            [](PyOperationBase &self, nb::object other) { return false; })
3108f78fe0b7Srkayaith       .def("__hash__",
3109f78fe0b7Srkayaith            [](PyOperationBase &self) {
3110f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3111f78fe0b7Srkayaith            })
3112b56d1ec6SPeter Hawkins       .def_prop_ro("attributes",
3113436c6c9cSStella Laurenzo                    [](PyOperationBase &self) {
3114b56d1ec6SPeter Hawkins                      return PyOpAttributeMap(self.getOperation().getRef());
3115436c6c9cSStella Laurenzo                    })
3116b56d1ec6SPeter Hawkins       .def_prop_ro(
311733df617dSStella Laurenzo           "context",
311833df617dSStella Laurenzo           [](PyOperationBase &self) {
311933df617dSStella Laurenzo             PyOperation &concreteOperation = self.getOperation();
312033df617dSStella Laurenzo             concreteOperation.checkValid();
312133df617dSStella Laurenzo             return concreteOperation.getContext().getObject();
312233df617dSStella Laurenzo           },
312333df617dSStella Laurenzo           "Context that owns the Operation")
3124b56d1ec6SPeter Hawkins       .def_prop_ro("name",
312533df617dSStella Laurenzo                    [](PyOperationBase &self) {
312633df617dSStella Laurenzo                      auto &concreteOperation = self.getOperation();
312733df617dSStella Laurenzo                      concreteOperation.checkValid();
3128b56d1ec6SPeter Hawkins                      MlirOperation operation = concreteOperation.get();
3129b56d1ec6SPeter Hawkins                      MlirStringRef name =
3130b56d1ec6SPeter Hawkins                          mlirIdentifierStr(mlirOperationGetName(operation));
3131b56d1ec6SPeter Hawkins                      return nb::str(name.data, name.length);
313233df617dSStella Laurenzo                    })
3133b56d1ec6SPeter Hawkins       .def_prop_ro("operands",
3134436c6c9cSStella Laurenzo                    [](PyOperationBase &self) {
3135b56d1ec6SPeter Hawkins                      return PyOpOperandList(self.getOperation().getRef());
3136436c6c9cSStella Laurenzo                    })
3137b56d1ec6SPeter Hawkins       .def_prop_ro("regions",
3138436c6c9cSStella Laurenzo                    [](PyOperationBase &self) {
3139b56d1ec6SPeter Hawkins                      return PyRegionList(self.getOperation().getRef());
3140436c6c9cSStella Laurenzo                    })
3141b56d1ec6SPeter Hawkins       .def_prop_ro(
3142436c6c9cSStella Laurenzo           "results",
3143436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
3144436c6c9cSStella Laurenzo             return PyOpResultList(self.getOperation().getRef());
3145436c6c9cSStella Laurenzo           },
3146436c6c9cSStella Laurenzo           "Returns the list of Operation results.")
3147b56d1ec6SPeter Hawkins       .def_prop_ro(
3148436c6c9cSStella Laurenzo           "result",
3149436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
3150436c6c9cSStella Laurenzo             auto &operation = self.getOperation();
3151*acde3f72SPeter Hawkins             return PyOpResult(operation.getRef(), getUniqueResult(operation))
31527c850867SMaksim Levental                 .maybeDownCast();
3153436c6c9cSStella Laurenzo           },
3154436c6c9cSStella Laurenzo           "Shortcut to get an op result if it has only one (throws an error "
3155436c6c9cSStella Laurenzo           "otherwise).")
3156b56d1ec6SPeter Hawkins       .def_prop_ro(
3157d5429a13Srkayaith           "location",
3158d5429a13Srkayaith           [](PyOperationBase &self) {
3159d5429a13Srkayaith             PyOperation &operation = self.getOperation();
3160d5429a13Srkayaith             return PyLocation(operation.getContext(),
3161d5429a13Srkayaith                               mlirOperationGetLocation(operation.get()));
3162d5429a13Srkayaith           },
3163d5429a13Srkayaith           "Returns the source location the operation was defined or derived "
3164d5429a13Srkayaith           "from.")
3165b56d1ec6SPeter Hawkins       .def_prop_ro("parent",
3166b56d1ec6SPeter Hawkins                    [](PyOperationBase &self) -> nb::object {
3167b56d1ec6SPeter Hawkins                      auto parent = self.getOperation().getParentOperation();
316833df617dSStella Laurenzo                      if (parent)
316933df617dSStella Laurenzo                        return parent->getObject();
3170b56d1ec6SPeter Hawkins                      return nb::none();
317133df617dSStella Laurenzo                    })
3172436c6c9cSStella Laurenzo       .def(
3173436c6c9cSStella Laurenzo           "__str__",
3174436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
3175436c6c9cSStella Laurenzo             return self.getAsm(/*binary=*/false,
3176e823ababSKazu Hirata                                /*largeElementsLimit=*/std::nullopt,
3177436c6c9cSStella Laurenzo                                /*enableDebugInfo=*/false,
3178436c6c9cSStella Laurenzo                                /*prettyDebugInfo=*/false,
3179436c6c9cSStella Laurenzo                                /*printGenericOpForm=*/false,
3180ace1d0adSStella Laurenzo                                /*useLocalScope=*/false,
3181abad8455SJonas Rickert                                /*assumeVerified=*/false,
3182abad8455SJonas Rickert                                /*skipRegions=*/false);
3183436c6c9cSStella Laurenzo           },
3184436c6c9cSStella Laurenzo           "Returns the assembly form of the operation.")
3185204acc5cSJacques Pienaar       .def("print",
3186b56d1ec6SPeter Hawkins            nb::overload_cast<PyAsmState &, nb::object, bool>(
3187204acc5cSJacques Pienaar                &PyOperationBase::print),
3188b56d1ec6SPeter Hawkins            nb::arg("state"), nb::arg("file").none() = nb::none(),
3189b56d1ec6SPeter Hawkins            nb::arg("binary") = false, kOperationPrintStateDocstring)
3190204acc5cSJacques Pienaar       .def("print",
3191b56d1ec6SPeter Hawkins            nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3192b56d1ec6SPeter Hawkins                              bool, nb::object, bool, bool>(
3193abad8455SJonas Rickert                &PyOperationBase::print),
3194436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with print method.
3195b56d1ec6SPeter Hawkins            nb::arg("large_elements_limit").none() = nb::none(),
3196b56d1ec6SPeter Hawkins            nb::arg("enable_debug_info") = false,
3197b56d1ec6SPeter Hawkins            nb::arg("pretty_debug_info") = false,
3198b56d1ec6SPeter Hawkins            nb::arg("print_generic_op_form") = false,
3199b56d1ec6SPeter Hawkins            nb::arg("use_local_scope") = false,
3200b56d1ec6SPeter Hawkins            nb::arg("assume_verified") = false,
3201b56d1ec6SPeter Hawkins            nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3202b56d1ec6SPeter Hawkins            nb::arg("skip_regions") = false, kOperationPrintDocstring)
3203b56d1ec6SPeter Hawkins       .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3204b56d1ec6SPeter Hawkins            nb::arg("desired_version").none() = nb::none(),
320589418ddcSMehdi Amini            kOperationPrintBytecodeDocstring)
3206436c6c9cSStella Laurenzo       .def("get_asm", &PyOperationBase::getAsm,
3207436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with get_asm method.
3208b56d1ec6SPeter Hawkins            nb::arg("binary") = false,
3209b56d1ec6SPeter Hawkins            nb::arg("large_elements_limit").none() = nb::none(),
3210b56d1ec6SPeter Hawkins            nb::arg("enable_debug_info") = false,
3211b56d1ec6SPeter Hawkins            nb::arg("pretty_debug_info") = false,
3212b56d1ec6SPeter Hawkins            nb::arg("print_generic_op_form") = false,
3213b56d1ec6SPeter Hawkins            nb::arg("use_local_scope") = false,
3214b56d1ec6SPeter Hawkins            nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3215abad8455SJonas Rickert            kOperationGetAsmDocstring)
32163ea4c501SRahul Kayaith       .def("verify", &PyOperationBase::verify,
32173ea4c501SRahul Kayaith            "Verify the operation. Raises MLIRError if verification fails, and "
32183ea4c501SRahul Kayaith            "returns true otherwise.")
3219b56d1ec6SPeter Hawkins       .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
322024685aaeSAlex Zinenko            "Puts self immediately after the other operation in its parent "
322124685aaeSAlex Zinenko            "block.")
3222b56d1ec6SPeter Hawkins       .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
322324685aaeSAlex Zinenko            "Puts self immediately before the other operation in its parent "
322424685aaeSAlex Zinenko            "block.")
322524685aaeSAlex Zinenko       .def(
322633df617dSStella Laurenzo           "clone",
3227b56d1ec6SPeter Hawkins           [](PyOperationBase &self, nb::object ip) {
322833df617dSStella Laurenzo             return self.getOperation().clone(ip);
322933df617dSStella Laurenzo           },
3230b56d1ec6SPeter Hawkins           nb::arg("ip").none() = nb::none())
323133df617dSStella Laurenzo       .def(
323224685aaeSAlex Zinenko           "detach_from_parent",
323324685aaeSAlex Zinenko           [](PyOperationBase &self) {
323424685aaeSAlex Zinenko             PyOperation &operation = self.getOperation();
323524685aaeSAlex Zinenko             operation.checkValid();
323624685aaeSAlex Zinenko             if (!operation.isAttached())
3237b56d1ec6SPeter Hawkins               throw nb::value_error("Detached operation has no parent.");
323824685aaeSAlex Zinenko 
323924685aaeSAlex Zinenko             operation.detachFromParent();
324024685aaeSAlex Zinenko             return operation.createOpView();
324124685aaeSAlex Zinenko           },
324233df617dSStella Laurenzo           "Detaches the operation from its parent block.")
324347148832SHideto Ueno       .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3244b56d1ec6SPeter Hawkins       .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3245b56d1ec6SPeter Hawkins            nb::arg("walk_order") = MlirWalkPostOrder);
3246436c6c9cSStella Laurenzo 
3247b56d1ec6SPeter Hawkins   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3248*acde3f72SPeter Hawkins       .def_static(
3249*acde3f72SPeter Hawkins           "create",
3250*acde3f72SPeter Hawkins           [](std::string_view name,
3251*acde3f72SPeter Hawkins              std::optional<std::vector<PyType *>> results,
3252*acde3f72SPeter Hawkins              std::optional<std::vector<PyValue *>> operands,
3253*acde3f72SPeter Hawkins              std::optional<nb::dict> attributes,
3254*acde3f72SPeter Hawkins              std::optional<std::vector<PyBlock *>> successors, int regions,
3255*acde3f72SPeter Hawkins              DefaultingPyLocation location, const nb::object &maybeIp,
3256*acde3f72SPeter Hawkins              bool inferType) {
3257*acde3f72SPeter Hawkins             // Unpack/validate operands.
3258*acde3f72SPeter Hawkins             llvm::SmallVector<MlirValue, 4> mlirOperands;
3259*acde3f72SPeter Hawkins             if (operands) {
3260*acde3f72SPeter Hawkins               mlirOperands.reserve(operands->size());
3261*acde3f72SPeter Hawkins               for (PyValue *operand : *operands) {
3262*acde3f72SPeter Hawkins                 if (!operand)
3263*acde3f72SPeter Hawkins                   throw nb::value_error("operand value cannot be None");
3264*acde3f72SPeter Hawkins                 mlirOperands.push_back(operand->get());
3265*acde3f72SPeter Hawkins               }
3266*acde3f72SPeter Hawkins             }
3267*acde3f72SPeter Hawkins 
3268*acde3f72SPeter Hawkins             return PyOperation::create(name, results, mlirOperands, attributes,
3269*acde3f72SPeter Hawkins                                        successors, regions, location, maybeIp,
3270*acde3f72SPeter Hawkins                                        inferType);
3271*acde3f72SPeter Hawkins           },
3272*acde3f72SPeter Hawkins           nb::arg("name"), nb::arg("results").none() = nb::none(),
3273b56d1ec6SPeter Hawkins           nb::arg("operands").none() = nb::none(),
3274b56d1ec6SPeter Hawkins           nb::arg("attributes").none() = nb::none(),
3275*acde3f72SPeter Hawkins           nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
3276*acde3f72SPeter Hawkins           nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3277b56d1ec6SPeter Hawkins           nb::arg("infer_type") = false, kOperationCreateDocstring)
327837107e17Srkayaith       .def_static(
327937107e17Srkayaith           "parse",
328037107e17Srkayaith           [](const std::string &sourceStr, const std::string &sourceName,
328137107e17Srkayaith              DefaultingPyMlirContext context) {
328237107e17Srkayaith             return PyOperation::parse(context->getRef(), sourceStr, sourceName)
328337107e17Srkayaith                 ->createOpView();
328437107e17Srkayaith           },
3285b56d1ec6SPeter Hawkins           nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3286b56d1ec6SPeter Hawkins           nb::arg("context").none() = nb::none(),
328737107e17Srkayaith           "Parses an operation. Supports both text assembly format and binary "
328837107e17Srkayaith           "bytecode format.")
3289b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
32900126e906SJohn Demme       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3291b56d1ec6SPeter Hawkins       .def_prop_ro("operation", [](nb::object self) { return self; })
3292b56d1ec6SPeter Hawkins       .def_prop_ro("opview", &PyOperation::createOpView)
3293b56d1ec6SPeter Hawkins       .def_prop_ro(
3294d7e49736SMaksim Levental           "successors",
3295d7e49736SMaksim Levental           [](PyOperationBase &self) {
3296d7e49736SMaksim Levental             return PyOpSuccessors(self.getOperation().getRef());
3297d7e49736SMaksim Levental           },
3298d7e49736SMaksim Levental           "Returns the list of Operation successors.");
3299436c6c9cSStella Laurenzo 
3300436c6c9cSStella Laurenzo   auto opViewClass =
3301b56d1ec6SPeter Hawkins       nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3302b56d1ec6SPeter Hawkins           .def(nb::init<nb::object>(), nb::arg("operation"))
3303f4125e02SPeter Hawkins           .def(
3304f4125e02SPeter Hawkins               "__init__",
3305f4125e02SPeter Hawkins               [](PyOpView *self, std::string_view name,
3306f4125e02SPeter Hawkins                  std::tuple<int, bool> opRegionSpec,
3307f4125e02SPeter Hawkins                  nb::object operandSegmentSpecObj,
3308f4125e02SPeter Hawkins                  nb::object resultSegmentSpecObj,
3309f4125e02SPeter Hawkins                  std::optional<nb::list> resultTypeList, nb::list operandList,
3310f4125e02SPeter Hawkins                  std::optional<nb::dict> attributes,
3311f4125e02SPeter Hawkins                  std::optional<std::vector<PyBlock *>> successors,
3312f4125e02SPeter Hawkins                  std::optional<int> regions, DefaultingPyLocation location,
3313f4125e02SPeter Hawkins                  const nb::object &maybeIp) {
3314f4125e02SPeter Hawkins                 new (self) PyOpView(PyOpView::buildGeneric(
3315f4125e02SPeter Hawkins                     name, opRegionSpec, operandSegmentSpecObj,
3316f4125e02SPeter Hawkins                     resultSegmentSpecObj, resultTypeList, operandList,
3317f4125e02SPeter Hawkins                     attributes, successors, regions, location, maybeIp));
3318f4125e02SPeter Hawkins               },
3319f4125e02SPeter Hawkins               nb::arg("name"), nb::arg("opRegionSpec"),
3320f4125e02SPeter Hawkins               nb::arg("operandSegmentSpecObj").none() = nb::none(),
3321f4125e02SPeter Hawkins               nb::arg("resultSegmentSpecObj").none() = nb::none(),
3322f4125e02SPeter Hawkins               nb::arg("results").none() = nb::none(),
3323f4125e02SPeter Hawkins               nb::arg("operands").none() = nb::none(),
3324f4125e02SPeter Hawkins               nb::arg("attributes").none() = nb::none(),
3325f4125e02SPeter Hawkins               nb::arg("successors").none() = nb::none(),
3326f4125e02SPeter Hawkins               nb::arg("regions").none() = nb::none(),
3327f4125e02SPeter Hawkins               nb::arg("loc").none() = nb::none(),
3328f4125e02SPeter Hawkins               nb::arg("ip").none() = nb::none())
3329f4125e02SPeter Hawkins 
3330b56d1ec6SPeter Hawkins           .def_prop_ro("operation", &PyOpView::getOperationObject)
3331b56d1ec6SPeter Hawkins           .def_prop_ro("opview", [](nb::object self) { return self; })
3332d7e49736SMaksim Levental           .def(
3333d7e49736SMaksim Levental               "__str__",
3334b56d1ec6SPeter Hawkins               [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3335b56d1ec6SPeter Hawkins           .def_prop_ro(
3336d7e49736SMaksim Levental               "successors",
3337d7e49736SMaksim Levental               [](PyOperationBase &self) {
3338d7e49736SMaksim Levental                 return PyOpSuccessors(self.getOperation().getRef());
3339d7e49736SMaksim Levental               },
3340d7e49736SMaksim Levental               "Returns the list of Operation successors.");
3341b56d1ec6SPeter Hawkins   opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3342b56d1ec6SPeter Hawkins   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3343b56d1ec6SPeter Hawkins   opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3344f4125e02SPeter Hawkins   // It is faster to pass the operation_name, ods_regions, and
3345f4125e02SPeter Hawkins   // ods_operand_segments/ods_result_segments as arguments to the constructor,
3346f4125e02SPeter Hawkins   // rather than to access them as attributes.
3347436c6c9cSStella Laurenzo   opViewClass.attr("build_generic") = classmethod(
3348f4125e02SPeter Hawkins       [](nb::handle cls, std::optional<nb::list> resultTypeList,
3349f4125e02SPeter Hawkins          nb::list operandList, std::optional<nb::dict> attributes,
3350f4125e02SPeter Hawkins          std::optional<std::vector<PyBlock *>> successors,
3351f4125e02SPeter Hawkins          std::optional<int> regions, DefaultingPyLocation location,
3352f4125e02SPeter Hawkins          const nb::object &maybeIp) {
3353f4125e02SPeter Hawkins         std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3354f4125e02SPeter Hawkins         std::tuple<int, bool> opRegionSpec =
3355f4125e02SPeter Hawkins             nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3356f4125e02SPeter Hawkins         nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3357f4125e02SPeter Hawkins         nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3358f4125e02SPeter Hawkins         return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3359f4125e02SPeter Hawkins                                       resultSegmentSpec, resultTypeList,
3360f4125e02SPeter Hawkins                                       operandList, attributes, successors,
3361f4125e02SPeter Hawkins                                       regions, location, maybeIp);
3362f4125e02SPeter Hawkins       },
3363f4125e02SPeter Hawkins       nb::arg("cls"), nb::arg("results").none() = nb::none(),
3364b56d1ec6SPeter Hawkins       nb::arg("operands").none() = nb::none(),
3365b56d1ec6SPeter Hawkins       nb::arg("attributes").none() = nb::none(),
3366b56d1ec6SPeter Hawkins       nb::arg("successors").none() = nb::none(),
3367b56d1ec6SPeter Hawkins       nb::arg("regions").none() = nb::none(),
3368b56d1ec6SPeter Hawkins       nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3369436c6c9cSStella Laurenzo       "Builds a specific, generated OpView based on class level attributes.");
337037107e17Srkayaith   opViewClass.attr("parse") = classmethod(
3371b56d1ec6SPeter Hawkins       [](const nb::object &cls, const std::string &sourceStr,
337237107e17Srkayaith          const std::string &sourceName, DefaultingPyMlirContext context) {
337337107e17Srkayaith         PyOperationRef parsed =
337437107e17Srkayaith             PyOperation::parse(context->getRef(), sourceStr, sourceName);
337537107e17Srkayaith 
337637107e17Srkayaith         // Check if the expected operation was parsed, and cast to to the
337737107e17Srkayaith         // appropriate `OpView` subclass if successful.
337837107e17Srkayaith         // NOTE: This accesses attributes that have been automatically added to
337937107e17Srkayaith         // `OpView` subclasses, and is not intended to be used on `OpView`
338037107e17Srkayaith         // directly.
338137107e17Srkayaith         std::string clsOpName =
3382b56d1ec6SPeter Hawkins             nb::cast<std::string>(cls.attr("OPERATION_NAME"));
33833ea4c501SRahul Kayaith         MlirStringRef identifier =
338437107e17Srkayaith             mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
33853ea4c501SRahul Kayaith         std::string_view parsedOpName(identifier.data, identifier.length);
33863ea4c501SRahul Kayaith         if (clsOpName != parsedOpName)
33873ea4c501SRahul Kayaith           throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
33883ea4c501SRahul Kayaith                           parsedOpName + "'");
3389b56d1ec6SPeter Hawkins         return PyOpView::constructDerived(cls, parsed.getObject());
339037107e17Srkayaith       },
3391b56d1ec6SPeter Hawkins       nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3392b56d1ec6SPeter Hawkins       nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
339337107e17Srkayaith       "Parses a specific, generated OpView based on class level attributes");
3394436c6c9cSStella Laurenzo 
3395436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3396436c6c9cSStella Laurenzo   // Mapping of PyRegion.
3397436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3398b56d1ec6SPeter Hawkins   nb::class_<PyRegion>(m, "Region")
3399b56d1ec6SPeter Hawkins       .def_prop_ro(
3400436c6c9cSStella Laurenzo           "blocks",
3401436c6c9cSStella Laurenzo           [](PyRegion &self) {
3402436c6c9cSStella Laurenzo             return PyBlockList(self.getParentOperation(), self.get());
3403436c6c9cSStella Laurenzo           },
3404436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of blocks.")
3405b56d1ec6SPeter Hawkins       .def_prop_ro(
340678f2dae0SAlex Zinenko           "owner",
340778f2dae0SAlex Zinenko           [](PyRegion &self) {
340878f2dae0SAlex Zinenko             return self.getParentOperation()->createOpView();
340978f2dae0SAlex Zinenko           },
341078f2dae0SAlex Zinenko           "Returns the operation owning this region.")
3411436c6c9cSStella Laurenzo       .def(
3412436c6c9cSStella Laurenzo           "__iter__",
3413436c6c9cSStella Laurenzo           [](PyRegion &self) {
3414436c6c9cSStella Laurenzo             self.checkValid();
3415436c6c9cSStella Laurenzo             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3416436c6c9cSStella Laurenzo             return PyBlockIterator(self.getParentOperation(), firstBlock);
3417436c6c9cSStella Laurenzo           },
3418436c6c9cSStella Laurenzo           "Iterates over blocks in the region.")
3419436c6c9cSStella Laurenzo       .def("__eq__",
3420436c6c9cSStella Laurenzo            [](PyRegion &self, PyRegion &other) {
3421436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
3422436c6c9cSStella Laurenzo            })
3423b56d1ec6SPeter Hawkins       .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3424436c6c9cSStella Laurenzo 
3425436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3426436c6c9cSStella Laurenzo   // Mapping of PyBlock.
3427436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3428b56d1ec6SPeter Hawkins   nb::class_<PyBlock>(m, "Block")
3429b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3430b56d1ec6SPeter Hawkins       .def_prop_ro(
343196fbd5cdSJohn Demme           "owner",
343296fbd5cdSJohn Demme           [](PyBlock &self) {
343396fbd5cdSJohn Demme             return self.getParentOperation()->createOpView();
343496fbd5cdSJohn Demme           },
343596fbd5cdSJohn Demme           "Returns the owning operation of this block.")
3436b56d1ec6SPeter Hawkins       .def_prop_ro(
34378e6c55c9SStella Laurenzo           "region",
34388e6c55c9SStella Laurenzo           [](PyBlock &self) {
34398e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
34408e6c55c9SStella Laurenzo             return PyRegion(self.getParentOperation(), region);
34418e6c55c9SStella Laurenzo           },
34428e6c55c9SStella Laurenzo           "Returns the owning region of this block.")
3443b56d1ec6SPeter Hawkins       .def_prop_ro(
3444436c6c9cSStella Laurenzo           "arguments",
3445436c6c9cSStella Laurenzo           [](PyBlock &self) {
3446436c6c9cSStella Laurenzo             return PyBlockArgumentList(self.getParentOperation(), self.get());
3447436c6c9cSStella Laurenzo           },
3448436c6c9cSStella Laurenzo           "Returns a list of block arguments.")
344955d2fffdSSandeep Dasgupta       .def(
345055d2fffdSSandeep Dasgupta           "add_argument",
345155d2fffdSSandeep Dasgupta           [](PyBlock &self, const PyType &type, const PyLocation &loc) {
345255d2fffdSSandeep Dasgupta             return mlirBlockAddArgument(self.get(), type, loc);
345355d2fffdSSandeep Dasgupta           },
345455d2fffdSSandeep Dasgupta           "Append an argument of the specified type to the block and returns "
345555d2fffdSSandeep Dasgupta           "the newly added argument.")
345655d2fffdSSandeep Dasgupta       .def(
345755d2fffdSSandeep Dasgupta           "erase_argument",
345855d2fffdSSandeep Dasgupta           [](PyBlock &self, unsigned index) {
345955d2fffdSSandeep Dasgupta             return mlirBlockEraseArgument(self.get(), index);
346055d2fffdSSandeep Dasgupta           },
346155d2fffdSSandeep Dasgupta           "Erase the argument at 'index' and remove it from the argument list.")
3462b56d1ec6SPeter Hawkins       .def_prop_ro(
3463436c6c9cSStella Laurenzo           "operations",
3464436c6c9cSStella Laurenzo           [](PyBlock &self) {
3465436c6c9cSStella Laurenzo             return PyOperationList(self.getParentOperation(), self.get());
3466436c6c9cSStella Laurenzo           },
3467436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of operations.")
346878f2dae0SAlex Zinenko       .def_static(
346978f2dae0SAlex Zinenko           "create_at_start",
3470b56d1ec6SPeter Hawkins           [](PyRegion &parent, const nb::sequence &pyArgTypes,
3471b56d1ec6SPeter Hawkins              const std::optional<nb::sequence> &pyArgLocs) {
347278f2dae0SAlex Zinenko             parent.checkValid();
3473514dddbeSRahul Kayaith             MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
347478f2dae0SAlex Zinenko             mlirRegionInsertOwnedBlock(parent, 0, block);
347578f2dae0SAlex Zinenko             return PyBlock(parent.getParentOperation(), block);
347678f2dae0SAlex Zinenko           },
3477b56d1ec6SPeter Hawkins           nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3478b56d1ec6SPeter Hawkins           nb::arg("arg_locs") = std::nullopt,
347978f2dae0SAlex Zinenko           "Creates and returns a new Block at the beginning of the given "
3480514dddbeSRahul Kayaith           "region (with given argument types and locations).")
3481436c6c9cSStella Laurenzo       .def(
34828d8738f6SJohn Demme           "append_to",
34838d8738f6SJohn Demme           [](PyBlock &self, PyRegion &region) {
34848d8738f6SJohn Demme             MlirBlock b = self.get();
34858d8738f6SJohn Demme             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
34868d8738f6SJohn Demme               mlirBlockDetach(b);
34878d8738f6SJohn Demme             mlirRegionAppendOwnedBlock(region.get(), b);
34888d8738f6SJohn Demme           },
34898d8738f6SJohn Demme           "Append this block to a region, transferring ownership if necessary")
34908d8738f6SJohn Demme       .def(
34918e6c55c9SStella Laurenzo           "create_before",
3492b56d1ec6SPeter Hawkins           [](PyBlock &self, const nb::args &pyArgTypes,
3493b56d1ec6SPeter Hawkins              const std::optional<nb::sequence> &pyArgLocs) {
34948e6c55c9SStella Laurenzo             self.checkValid();
3495b56d1ec6SPeter Hawkins             MlirBlock block =
3496b56d1ec6SPeter Hawkins                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
34978e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
34988e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
34998e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
35008e6c55c9SStella Laurenzo           },
3501b56d1ec6SPeter Hawkins           nb::arg("arg_types"), nb::kw_only(),
3502b56d1ec6SPeter Hawkins           nb::arg("arg_locs") = std::nullopt,
35038e6c55c9SStella Laurenzo           "Creates and returns a new Block before this block "
3504514dddbeSRahul Kayaith           "(with given argument types and locations).")
35058e6c55c9SStella Laurenzo       .def(
35068e6c55c9SStella Laurenzo           "create_after",
3507b56d1ec6SPeter Hawkins           [](PyBlock &self, const nb::args &pyArgTypes,
3508b56d1ec6SPeter Hawkins              const std::optional<nb::sequence> &pyArgLocs) {
35098e6c55c9SStella Laurenzo             self.checkValid();
3510b56d1ec6SPeter Hawkins             MlirBlock block =
3511b56d1ec6SPeter Hawkins                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
35128e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
35138e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
35148e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
35158e6c55c9SStella Laurenzo           },
3516b56d1ec6SPeter Hawkins           nb::arg("arg_types"), nb::kw_only(),
3517b56d1ec6SPeter Hawkins           nb::arg("arg_locs") = std::nullopt,
35188e6c55c9SStella Laurenzo           "Creates and returns a new Block after this block "
3519514dddbeSRahul Kayaith           "(with given argument types and locations).")
35208e6c55c9SStella Laurenzo       .def(
3521436c6c9cSStella Laurenzo           "__iter__",
3522436c6c9cSStella Laurenzo           [](PyBlock &self) {
3523436c6c9cSStella Laurenzo             self.checkValid();
3524436c6c9cSStella Laurenzo             MlirOperation firstOperation =
3525436c6c9cSStella Laurenzo                 mlirBlockGetFirstOperation(self.get());
3526436c6c9cSStella Laurenzo             return PyOperationIterator(self.getParentOperation(),
3527436c6c9cSStella Laurenzo                                        firstOperation);
3528436c6c9cSStella Laurenzo           },
3529436c6c9cSStella Laurenzo           "Iterates over operations in the block.")
3530436c6c9cSStella Laurenzo       .def("__eq__",
3531436c6c9cSStella Laurenzo            [](PyBlock &self, PyBlock &other) {
3532436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
3533436c6c9cSStella Laurenzo            })
3534b56d1ec6SPeter Hawkins       .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3535fa45b2fbSMike Urbach       .def("__hash__",
3536fa45b2fbSMike Urbach            [](PyBlock &self) {
3537fa45b2fbSMike Urbach              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3538fa45b2fbSMike Urbach            })
3539436c6c9cSStella Laurenzo       .def(
3540436c6c9cSStella Laurenzo           "__str__",
3541436c6c9cSStella Laurenzo           [](PyBlock &self) {
3542436c6c9cSStella Laurenzo             self.checkValid();
3543436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
3544436c6c9cSStella Laurenzo             mlirBlockPrint(self.get(), printAccum.getCallback(),
3545436c6c9cSStella Laurenzo                            printAccum.getUserData());
3546436c6c9cSStella Laurenzo             return printAccum.join();
3547436c6c9cSStella Laurenzo           },
354824685aaeSAlex Zinenko           "Returns the assembly form of the block.")
354924685aaeSAlex Zinenko       .def(
355024685aaeSAlex Zinenko           "append",
355124685aaeSAlex Zinenko           [](PyBlock &self, PyOperationBase &operation) {
355224685aaeSAlex Zinenko             if (operation.getOperation().isAttached())
355324685aaeSAlex Zinenko               operation.getOperation().detachFromParent();
355424685aaeSAlex Zinenko 
355524685aaeSAlex Zinenko             MlirOperation mlirOperation = operation.getOperation().get();
355624685aaeSAlex Zinenko             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
355724685aaeSAlex Zinenko             operation.getOperation().setAttached(
355824685aaeSAlex Zinenko                 self.getParentOperation().getObject());
355924685aaeSAlex Zinenko           },
3560b56d1ec6SPeter Hawkins           nb::arg("operation"),
356124685aaeSAlex Zinenko           "Appends an operation to this block. If the operation is currently "
356224685aaeSAlex Zinenko           "in another block, it will be moved.");
3563436c6c9cSStella Laurenzo 
3564436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3565436c6c9cSStella Laurenzo   // Mapping of PyInsertionPoint.
3566436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3567436c6c9cSStella Laurenzo 
3568b56d1ec6SPeter Hawkins   nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3569b56d1ec6SPeter Hawkins       .def(nb::init<PyBlock &>(), nb::arg("block"),
3570436c6c9cSStella Laurenzo            "Inserts after the last operation but still inside the block.")
3571436c6c9cSStella Laurenzo       .def("__enter__", &PyInsertionPoint::contextEnter)
3572b56d1ec6SPeter Hawkins       .def("__exit__", &PyInsertionPoint::contextExit,
3573b56d1ec6SPeter Hawkins            nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3574b56d1ec6SPeter Hawkins            nb::arg("traceback").none())
3575b56d1ec6SPeter Hawkins       .def_prop_ro_static(
3576436c6c9cSStella Laurenzo           "current",
3577b56d1ec6SPeter Hawkins           [](nb::object & /*class*/) {
3578436c6c9cSStella Laurenzo             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3579436c6c9cSStella Laurenzo             if (!ip)
3580b56d1ec6SPeter Hawkins               throw nb::value_error("No current InsertionPoint");
3581436c6c9cSStella Laurenzo             return ip;
3582436c6c9cSStella Laurenzo           },
3583436c6c9cSStella Laurenzo           "Gets the InsertionPoint bound to the current thread or raises "
3584436c6c9cSStella Laurenzo           "ValueError if none has been set")
3585b56d1ec6SPeter Hawkins       .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3586436c6c9cSStella Laurenzo            "Inserts before a referenced operation.")
3587436c6c9cSStella Laurenzo       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3588b56d1ec6SPeter Hawkins                   nb::arg("block"), "Inserts at the beginning of the block.")
3589436c6c9cSStella Laurenzo       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3590b56d1ec6SPeter Hawkins                   nb::arg("block"), "Inserts before the block terminator.")
3591b56d1ec6SPeter Hawkins       .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
35928e6c55c9SStella Laurenzo            "Inserts an operation.")
3593b56d1ec6SPeter Hawkins       .def_prop_ro(
35948e6c55c9SStella Laurenzo           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
35955a600c23STomás Longeri           "Returns the block that this InsertionPoint points to.")
3596b56d1ec6SPeter Hawkins       .def_prop_ro(
35975a600c23STomás Longeri           "ref_operation",
3598b56d1ec6SPeter Hawkins           [](PyInsertionPoint &self) -> nb::object {
3599b1d682e0SMehdi Amini             auto refOperation = self.getRefOperation();
3600b1d682e0SMehdi Amini             if (refOperation)
3601b1d682e0SMehdi Amini               return refOperation->getObject();
3602b56d1ec6SPeter Hawkins             return nb::none();
36035a600c23STomás Longeri           },
36045a600c23STomás Longeri           "The reference operation before which new operations are "
36055a600c23STomás Longeri           "inserted, or None if the insertion point is at the end of "
36065a600c23STomás Longeri           "the block");
3607436c6c9cSStella Laurenzo 
3608436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3609436c6c9cSStella Laurenzo   // Mapping of PyAttribute.
3610436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3611b56d1ec6SPeter Hawkins   nb::class_<PyAttribute>(m, "Attribute")
3612b57d6fe4SStella Laurenzo       // Delegate to the PyAttribute copy constructor, which will also lifetime
3613b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirAttribute.
3614b56d1ec6SPeter Hawkins       .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3615b57d6fe4SStella Laurenzo            "Casts the passed attribute to the generic Attribute")
3616b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3617436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3618436c6c9cSStella Laurenzo       .def_static(
3619436c6c9cSStella Laurenzo           "parse",
36204eee9ef9Smax           [](const std::string &attrSpec, DefaultingPyMlirContext context) {
36213ea4c501SRahul Kayaith             PyMlirContext::ErrorCapture errors(context->getRef());
36224eee9ef9Smax             MlirAttribute attr = mlirAttributeParseGet(
3623436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(attrSpec));
36244eee9ef9Smax             if (mlirAttributeIsNull(attr))
36253ea4c501SRahul Kayaith               throw MLIRError("Unable to parse attribute", errors.take());
36264eee9ef9Smax             return attr;
3627436c6c9cSStella Laurenzo           },
3628b56d1ec6SPeter Hawkins           nb::arg("asm"), nb::arg("context").none() = nb::none(),
36293ea4c501SRahul Kayaith           "Parses an attribute from an assembly form. Raises an MLIRError on "
36303ea4c501SRahul Kayaith           "failure.")
3631b56d1ec6SPeter Hawkins       .def_prop_ro(
3632436c6c9cSStella Laurenzo           "context",
3633436c6c9cSStella Laurenzo           [](PyAttribute &self) { return self.getContext().getObject(); },
3634436c6c9cSStella Laurenzo           "Context that owns the Attribute")
3635b56d1ec6SPeter Hawkins       .def_prop_ro("type",
3636b56d1ec6SPeter Hawkins                    [](PyAttribute &self) { return mlirAttributeGetType(self); })
3637436c6c9cSStella Laurenzo       .def(
3638436c6c9cSStella Laurenzo           "get_named",
3639436c6c9cSStella Laurenzo           [](PyAttribute &self, std::string name) {
3640436c6c9cSStella Laurenzo             return PyNamedAttribute(self, std::move(name));
3641436c6c9cSStella Laurenzo           },
3642b56d1ec6SPeter Hawkins           nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3643436c6c9cSStella Laurenzo       .def("__eq__",
3644436c6c9cSStella Laurenzo            [](PyAttribute &self, PyAttribute &other) { return self == other; })
3645b56d1ec6SPeter Hawkins       .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3646f78fe0b7Srkayaith       .def("__hash__",
3647f78fe0b7Srkayaith            [](PyAttribute &self) {
3648f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3649f78fe0b7Srkayaith            })
3650436c6c9cSStella Laurenzo       .def(
3651436c6c9cSStella Laurenzo           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3652436c6c9cSStella Laurenzo           kDumpDocstring)
3653436c6c9cSStella Laurenzo       .def(
3654436c6c9cSStella Laurenzo           "__str__",
3655436c6c9cSStella Laurenzo           [](PyAttribute &self) {
3656436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
3657436c6c9cSStella Laurenzo             mlirAttributePrint(self, printAccum.getCallback(),
3658436c6c9cSStella Laurenzo                                printAccum.getUserData());
3659436c6c9cSStella Laurenzo             return printAccum.join();
3660436c6c9cSStella Laurenzo           },
3661436c6c9cSStella Laurenzo           "Returns the assembly form of the Attribute.")
36629566ee28Smax       .def("__repr__",
36639566ee28Smax            [](PyAttribute &self) {
3664436c6c9cSStella Laurenzo              // Generally, assembly formats are not printed for __repr__ because
3665436c6c9cSStella Laurenzo              // this can cause exceptionally long debug output and exceptions.
36669566ee28Smax              // However, attribute values are generally considered useful and
36679566ee28Smax              // are printed. This may need to be re-evaluated if debug dumps end
36689566ee28Smax              // up being excessive.
3669436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
3670436c6c9cSStella Laurenzo              printAccum.parts.append("Attribute(");
3671436c6c9cSStella Laurenzo              mlirAttributePrint(self, printAccum.getCallback(),
3672436c6c9cSStella Laurenzo                                 printAccum.getUserData());
3673436c6c9cSStella Laurenzo              printAccum.parts.append(")");
3674436c6c9cSStella Laurenzo              return printAccum.join();
36759566ee28Smax            })
3676b56d1ec6SPeter Hawkins       .def_prop_ro("typeid",
36779566ee28Smax                    [](PyAttribute &self) -> MlirTypeID {
36789566ee28Smax                      MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
36799566ee28Smax                      assert(!mlirTypeIDIsNull(mlirTypeID) &&
36809566ee28Smax                             "mlirTypeID was expected to be non-null.");
36819566ee28Smax                      return mlirTypeID;
36829566ee28Smax                    })
36839566ee28Smax       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
36849566ee28Smax         MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
36859566ee28Smax         assert(!mlirTypeIDIsNull(mlirTypeID) &&
36869566ee28Smax                "mlirTypeID was expected to be non-null.");
3687b56d1ec6SPeter Hawkins         std::optional<nb::callable> typeCaster =
36889566ee28Smax             PyGlobals::get().lookupTypeCaster(mlirTypeID,
36899566ee28Smax                                               mlirAttributeGetDialect(self));
36909566ee28Smax         if (!typeCaster)
3691b56d1ec6SPeter Hawkins           return nb::cast(self);
36929566ee28Smax         return typeCaster.value()(self);
3693436c6c9cSStella Laurenzo       });
3694436c6c9cSStella Laurenzo 
3695436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3696436c6c9cSStella Laurenzo   // Mapping of PyNamedAttribute
3697436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3698b56d1ec6SPeter Hawkins   nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3699436c6c9cSStella Laurenzo       .def("__repr__",
3700436c6c9cSStella Laurenzo            [](PyNamedAttribute &self) {
3701436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
3702436c6c9cSStella Laurenzo              printAccum.parts.append("NamedAttribute(");
3703436c6c9cSStella Laurenzo              printAccum.parts.append(
3704b56d1ec6SPeter Hawkins                  nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3705120591e1SRiver Riddle                          mlirIdentifierStr(self.namedAttr.name).length));
3706436c6c9cSStella Laurenzo              printAccum.parts.append("=");
3707436c6c9cSStella Laurenzo              mlirAttributePrint(self.namedAttr.attribute,
3708436c6c9cSStella Laurenzo                                 printAccum.getCallback(),
3709436c6c9cSStella Laurenzo                                 printAccum.getUserData());
3710436c6c9cSStella Laurenzo              printAccum.parts.append(")");
3711436c6c9cSStella Laurenzo              return printAccum.join();
3712436c6c9cSStella Laurenzo            })
3713b56d1ec6SPeter Hawkins       .def_prop_ro(
3714436c6c9cSStella Laurenzo           "name",
3715436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
3716b56d1ec6SPeter Hawkins             return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3717436c6c9cSStella Laurenzo                            mlirIdentifierStr(self.namedAttr.name).length);
3718436c6c9cSStella Laurenzo           },
3719436c6c9cSStella Laurenzo           "The name of the NamedAttribute binding")
3720b56d1ec6SPeter Hawkins       .def_prop_ro(
3721436c6c9cSStella Laurenzo           "attr",
37229566ee28Smax           [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3723b56d1ec6SPeter Hawkins           nb::keep_alive<0, 1>(),
3724436c6c9cSStella Laurenzo           "The underlying generic attribute of the NamedAttribute binding");
3725436c6c9cSStella Laurenzo 
3726436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3727436c6c9cSStella Laurenzo   // Mapping of PyType.
3728436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3729b56d1ec6SPeter Hawkins   nb::class_<PyType>(m, "Type")
3730b57d6fe4SStella Laurenzo       // Delegate to the PyType copy constructor, which will also lifetime
3731b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirType.
3732b56d1ec6SPeter Hawkins       .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3733b57d6fe4SStella Laurenzo            "Casts the passed type to the generic Type")
3734b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3735436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3736436c6c9cSStella Laurenzo       .def_static(
3737436c6c9cSStella Laurenzo           "parse",
3738436c6c9cSStella Laurenzo           [](std::string typeSpec, DefaultingPyMlirContext context) {
37393ea4c501SRahul Kayaith             PyMlirContext::ErrorCapture errors(context->getRef());
3740436c6c9cSStella Laurenzo             MlirType type =
3741436c6c9cSStella Laurenzo                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
37423ea4c501SRahul Kayaith             if (mlirTypeIsNull(type))
37433ea4c501SRahul Kayaith               throw MLIRError("Unable to parse type", errors.take());
3744bfb1ba75Smax             return type;
3745436c6c9cSStella Laurenzo           },
3746b56d1ec6SPeter Hawkins           nb::arg("asm"), nb::arg("context").none() = nb::none(),
3747436c6c9cSStella Laurenzo           kContextParseTypeDocstring)
3748b56d1ec6SPeter Hawkins       .def_prop_ro(
3749436c6c9cSStella Laurenzo           "context", [](PyType &self) { return self.getContext().getObject(); },
3750436c6c9cSStella Laurenzo           "Context that owns the Type")
3751436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3752b56d1ec6SPeter Hawkins       .def(
3753b56d1ec6SPeter Hawkins           "__eq__", [](PyType &self, nb::object &other) { return false; },
3754b56d1ec6SPeter Hawkins           nb::arg("other").none())
3755f78fe0b7Srkayaith       .def("__hash__",
3756f78fe0b7Srkayaith            [](PyType &self) {
3757f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3758f78fe0b7Srkayaith            })
3759436c6c9cSStella Laurenzo       .def(
3760436c6c9cSStella Laurenzo           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3761436c6c9cSStella Laurenzo       .def(
3762436c6c9cSStella Laurenzo           "__str__",
3763436c6c9cSStella Laurenzo           [](PyType &self) {
3764436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
3765436c6c9cSStella Laurenzo             mlirTypePrint(self, printAccum.getCallback(),
3766436c6c9cSStella Laurenzo                           printAccum.getUserData());
3767436c6c9cSStella Laurenzo             return printAccum.join();
3768436c6c9cSStella Laurenzo           },
3769436c6c9cSStella Laurenzo           "Returns the assembly form of the type.")
3770d39a7844Smax       .def("__repr__",
3771d39a7844Smax            [](PyType &self) {
3772436c6c9cSStella Laurenzo              // Generally, assembly formats are not printed for __repr__ because
3773436c6c9cSStella Laurenzo              // this can cause exceptionally long debug output and exceptions.
3774436c6c9cSStella Laurenzo              // However, types are an exception as they typically have compact
3775436c6c9cSStella Laurenzo              // assembly forms and printing them is useful.
3776436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
3777436c6c9cSStella Laurenzo              printAccum.parts.append("Type(");
3778d39a7844Smax              mlirTypePrint(self, printAccum.getCallback(),
3779d39a7844Smax                            printAccum.getUserData());
3780436c6c9cSStella Laurenzo              printAccum.parts.append(")");
3781436c6c9cSStella Laurenzo              return printAccum.join();
3782d39a7844Smax            })
3783bfb1ba75Smax       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3784bfb1ba75Smax            [](PyType &self) {
3785bfb1ba75Smax              MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3786bfb1ba75Smax              assert(!mlirTypeIDIsNull(mlirTypeID) &&
3787bfb1ba75Smax                     "mlirTypeID was expected to be non-null.");
3788b56d1ec6SPeter Hawkins              std::optional<nb::callable> typeCaster =
3789bfb1ba75Smax                  PyGlobals::get().lookupTypeCaster(mlirTypeID,
3790bfb1ba75Smax                                                    mlirTypeGetDialect(self));
3791bfb1ba75Smax              if (!typeCaster)
3792b56d1ec6SPeter Hawkins                return nb::cast(self);
3793bfb1ba75Smax              return typeCaster.value()(self);
3794bfb1ba75Smax            })
3795b56d1ec6SPeter Hawkins       .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3796d39a7844Smax         MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3797d39a7844Smax         if (!mlirTypeIDIsNull(mlirTypeID))
3798d39a7844Smax           return mlirTypeID;
3799b56d1ec6SPeter Hawkins         auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3800b56d1ec6SPeter Hawkins         throw nb::value_error(
3801b56d1ec6SPeter Hawkins             (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3802d39a7844Smax       });
3803d39a7844Smax 
3804d39a7844Smax   //----------------------------------------------------------------------------
3805d39a7844Smax   // Mapping of PyTypeID.
3806d39a7844Smax   //----------------------------------------------------------------------------
3807b56d1ec6SPeter Hawkins   nb::class_<PyTypeID>(m, "TypeID")
3808b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3809d39a7844Smax       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3810d39a7844Smax       // Note, this tests whether the underlying TypeIDs are the same,
3811d39a7844Smax       // not whether the wrapper MlirTypeIDs are the same, nor whether
3812d39a7844Smax       // the Python objects are the same (i.e., PyTypeID is a value type).
3813d39a7844Smax       .def("__eq__",
3814d39a7844Smax            [](PyTypeID &self, PyTypeID &other) { return self == other; })
3815d39a7844Smax       .def("__eq__",
3816b56d1ec6SPeter Hawkins            [](PyTypeID &self, const nb::object &other) { return false; })
3817d39a7844Smax       // Note, this gives the hash value of the underlying TypeID, not the
3818d39a7844Smax       // hash value of the Python object, nor the hash value of the
3819d39a7844Smax       // MlirTypeID wrapper.
3820d39a7844Smax       .def("__hash__", [](PyTypeID &self) {
3821d39a7844Smax         return static_cast<size_t>(mlirTypeIDHashValue(self));
3822436c6c9cSStella Laurenzo       });
3823436c6c9cSStella Laurenzo 
3824436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3825436c6c9cSStella Laurenzo   // Mapping of Value.
3826436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3827b56d1ec6SPeter Hawkins   nb::class_<PyValue>(m, "Value")
3828b56d1ec6SPeter Hawkins       .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
3829b56d1ec6SPeter Hawkins       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
38303f3d1c90SMike Urbach       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3831b56d1ec6SPeter Hawkins       .def_prop_ro(
3832436c6c9cSStella Laurenzo           "context",
3833436c6c9cSStella Laurenzo           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3834436c6c9cSStella Laurenzo           "Context in which the value lives.")
3835436c6c9cSStella Laurenzo       .def(
3836436c6c9cSStella Laurenzo           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3837436c6c9cSStella Laurenzo           kDumpDocstring)
3838b56d1ec6SPeter Hawkins       .def_prop_ro(
38395664c5e2SJohn Demme           "owner",
3840b56d1ec6SPeter Hawkins           [](PyValue &self) -> nb::object {
3841d747a170SJohn Demme             MlirValue v = self.get();
3842d747a170SJohn Demme             if (mlirValueIsAOpResult(v)) {
3843d747a170SJohn Demme               assert(
3844d747a170SJohn Demme                   mlirOperationEqual(self.getParentOperation()->get(),
38455664c5e2SJohn Demme                                      mlirOpResultGetOwner(self.get())) &&
38465664c5e2SJohn Demme                   "expected the owner of the value in Python to match that in "
38475664c5e2SJohn Demme                   "the IR");
38485664c5e2SJohn Demme               return self.getParentOperation().getObject();
3849d747a170SJohn Demme             }
3850d747a170SJohn Demme 
3851d747a170SJohn Demme             if (mlirValueIsABlockArgument(v)) {
3852d747a170SJohn Demme               MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3853b56d1ec6SPeter Hawkins               return nb::cast(PyBlock(self.getParentOperation(), block));
3854d747a170SJohn Demme             }
3855d747a170SJohn Demme 
3856d747a170SJohn Demme             assert(false && "Value must be a block argument or an op result");
3857b56d1ec6SPeter Hawkins             return nb::none();
38585664c5e2SJohn Demme           })
3859b56d1ec6SPeter Hawkins       .def_prop_ro("uses",
3860afb2ed80SMike Urbach                    [](PyValue &self) {
3861afb2ed80SMike Urbach                      return PyOpOperandIterator(
3862afb2ed80SMike Urbach                          mlirValueGetFirstUse(self.get()));
3863afb2ed80SMike Urbach                    })
3864436c6c9cSStella Laurenzo       .def("__eq__",
3865436c6c9cSStella Laurenzo            [](PyValue &self, PyValue &other) {
3866436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
3867436c6c9cSStella Laurenzo            })
3868b56d1ec6SPeter Hawkins       .def("__eq__", [](PyValue &self, nb::object other) { return false; })
3869f78fe0b7Srkayaith       .def("__hash__",
3870f78fe0b7Srkayaith            [](PyValue &self) {
3871f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3872f78fe0b7Srkayaith            })
3873436c6c9cSStella Laurenzo       .def(
3874436c6c9cSStella Laurenzo           "__str__",
3875436c6c9cSStella Laurenzo           [](PyValue &self) {
3876436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
3877436c6c9cSStella Laurenzo             printAccum.parts.append("Value(");
3878436c6c9cSStella Laurenzo             mlirValuePrint(self.get(), printAccum.getCallback(),
3879436c6c9cSStella Laurenzo                            printAccum.getUserData());
3880436c6c9cSStella Laurenzo             printAccum.parts.append(")");
3881436c6c9cSStella Laurenzo             return printAccum.join();
3882436c6c9cSStella Laurenzo           },
3883436c6c9cSStella Laurenzo           kValueDunderStrDocstring)
388481233c70Smax       .def(
388581233c70Smax           "get_name",
3886f1dbfcc1SJacques Pienaar           [](PyValue &self, bool useLocalScope) {
388781233c70Smax             PyPrintAccumulator printAccum;
3888f1dbfcc1SJacques Pienaar             MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3889f1dbfcc1SJacques Pienaar             if (useLocalScope)
389081233c70Smax               mlirOpPrintingFlagsUseLocalScope(flags);
3891f1dbfcc1SJacques Pienaar             MlirAsmState valueState =
3892f1dbfcc1SJacques Pienaar                 mlirAsmStateCreateForValue(self.get(), flags);
3893d7e49736SMaksim Levental             mlirValuePrintAsOperand(self.get(), valueState,
3894d7e49736SMaksim Levental                                     printAccum.getCallback(),
389581233c70Smax                                     printAccum.getUserData());
389681233c70Smax             mlirOpPrintingFlagsDestroy(flags);
389775453714SJacques Pienaar             mlirAsmStateDestroy(valueState);
389881233c70Smax             return printAccum.join();
389981233c70Smax           },
3900b56d1ec6SPeter Hawkins           nb::arg("use_local_scope") = false)
3901f1dbfcc1SJacques Pienaar       .def(
3902f1dbfcc1SJacques Pienaar           "get_name",
3903b56d1ec6SPeter Hawkins           [](PyValue &self, PyAsmState &state) {
3904f1dbfcc1SJacques Pienaar             PyPrintAccumulator printAccum;
3905b56d1ec6SPeter Hawkins             MlirAsmState valueState = state.get();
3906f1dbfcc1SJacques Pienaar             mlirValuePrintAsOperand(self.get(), valueState,
3907f1dbfcc1SJacques Pienaar                                     printAccum.getCallback(),
3908f1dbfcc1SJacques Pienaar                                     printAccum.getUserData());
3909f1dbfcc1SJacques Pienaar             return printAccum.join();
3910f1dbfcc1SJacques Pienaar           },
3911b56d1ec6SPeter Hawkins           nb::arg("state"), kGetNameAsOperand)
3912b56d1ec6SPeter Hawkins       .def_prop_ro("type",
3913b56d1ec6SPeter Hawkins                    [](PyValue &self) { return mlirValueGetType(self.get()); })
39145b303f21Smax       .def(
391525b8433bSmax           "set_type",
391625b8433bSmax           [](PyValue &self, const PyType &type) {
391725b8433bSmax             return mlirValueSetType(self.get(), type);
391825b8433bSmax           },
3919b56d1ec6SPeter Hawkins           nb::arg("type"))
392025b8433bSmax       .def(
39215b303f21Smax           "replace_all_uses_with",
39225b303f21Smax           [](PyValue &self, PyValue &with) {
39235b303f21Smax             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
39245b303f21Smax           },
39257c850867SMaksim Levental           kValueReplaceAllUsesWithDocstring)
392621df3251SPerry Gibson       .def(
392721df3251SPerry Gibson           "replace_all_uses_except",
392821df3251SPerry Gibson           [](MlirValue self, MlirValue with, PyOperation &exception) {
392921df3251SPerry Gibson             MlirOperation exceptedUser = exception.get();
393021df3251SPerry Gibson             mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
393121df3251SPerry Gibson           },
3932b56d1ec6SPeter Hawkins           nb::arg("with"), nb::arg("exceptions"),
393321df3251SPerry Gibson           kValueReplaceAllUsesExceptDocstring)
393421df3251SPerry Gibson       .def(
393521df3251SPerry Gibson           "replace_all_uses_except",
3936b56d1ec6SPeter Hawkins           [](MlirValue self, MlirValue with, nb::list exceptions) {
393721df3251SPerry Gibson             // Convert Python list to a SmallVector of MlirOperations
393821df3251SPerry Gibson             llvm::SmallVector<MlirOperation> exceptionOps;
3939b56d1ec6SPeter Hawkins             for (nb::handle exception : exceptions) {
3940b56d1ec6SPeter Hawkins               exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
394121df3251SPerry Gibson             }
394221df3251SPerry Gibson 
394321df3251SPerry Gibson             mlirValueReplaceAllUsesExcept(
394421df3251SPerry Gibson                 self, with, static_cast<intptr_t>(exceptionOps.size()),
394521df3251SPerry Gibson                 exceptionOps.data());
394621df3251SPerry Gibson           },
3947b56d1ec6SPeter Hawkins           nb::arg("with"), nb::arg("exceptions"),
394821df3251SPerry Gibson           kValueReplaceAllUsesExceptDocstring)
39497c850867SMaksim Levental       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
39507c850867SMaksim Levental            [](PyValue &self) { return self.maybeDownCast(); });
  // Bindings for block arguments, op results and op operands.
  // NOTE(review): these come right after the `Value` class registration,
  // presumably because BlockArgument/OpResult are bound as subclasses of it.
  // Confirm this before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);
  PyOpOperand::bind(m);
3954436c6c9cSStella Laurenzo 
3955b56d1ec6SPeter Hawkins   nb::class_<PyAsmState>(m, "AsmState")
3956b56d1ec6SPeter Hawkins       .def(nb::init<PyValue &, bool>(), nb::arg("value"),
3957b56d1ec6SPeter Hawkins            nb::arg("use_local_scope") = false)
3958b56d1ec6SPeter Hawkins       .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
3959b56d1ec6SPeter Hawkins            nb::arg("use_local_scope") = false);
396075453714SJacques Pienaar 
396130d61893SAlex Zinenko   //----------------------------------------------------------------------------
396230d61893SAlex Zinenko   // Mapping of SymbolTable.
396330d61893SAlex Zinenko   //----------------------------------------------------------------------------
3964b56d1ec6SPeter Hawkins   nb::class_<PySymbolTable>(m, "SymbolTable")
3965b56d1ec6SPeter Hawkins       .def(nb::init<PyOperationBase &>())
396630d61893SAlex Zinenko       .def("__getitem__", &PySymbolTable::dunderGetItem)
3967b56d1ec6SPeter Hawkins       .def("insert", &PySymbolTable::insert, nb::arg("operation"))
3968b56d1ec6SPeter Hawkins       .def("erase", &PySymbolTable::erase, nb::arg("operation"))
396930d61893SAlex Zinenko       .def("__delitem__", &PySymbolTable::dunderDel)
3970bdc31837SStella Laurenzo       .def("__contains__",
3971bdc31837SStella Laurenzo            [](PySymbolTable &table, const std::string &name) {
397230d61893SAlex Zinenko              return !mlirOperationIsNull(mlirSymbolTableLookup(
397330d61893SAlex Zinenko                  table, mlirStringRefCreate(name.data(), name.length())));
3974bdc31837SStella Laurenzo            })
3975bdc31837SStella Laurenzo       // Static helpers.
3976bdc31837SStella Laurenzo       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3977b56d1ec6SPeter Hawkins                   nb::arg("symbol"), nb::arg("name"))
3978bdc31837SStella Laurenzo       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3979b56d1ec6SPeter Hawkins                   nb::arg("symbol"))
3980bdc31837SStella Laurenzo       .def_static("get_visibility", &PySymbolTable::getVisibility,
3981b56d1ec6SPeter Hawkins                   nb::arg("symbol"))
3982bdc31837SStella Laurenzo       .def_static("set_visibility", &PySymbolTable::setVisibility,
3983b56d1ec6SPeter Hawkins                   nb::arg("symbol"), nb::arg("visibility"))
3984bdc31837SStella Laurenzo       .def_static("replace_all_symbol_uses",
3985b56d1ec6SPeter Hawkins                   &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
3986b56d1ec6SPeter Hawkins                   nb::arg("new_symbol"), nb::arg("from_op"))
3987bdc31837SStella Laurenzo       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3988b56d1ec6SPeter Hawkins                   nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
3989b56d1ec6SPeter Hawkins                   nb::arg("callback"));
399030d61893SAlex Zinenko 
  // Container bindings: add to module `m` the bindings of each list/iterator
  // helper over block arguments, blocks, operations, operation attributes,
  // operands, results, successors and regions.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandIterator::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyOpSuccessors::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (global debug flag).
  PyGlobalDebugFlag::bind(m);

  // Attribute builder getter.
  PyAttrBuilderMap::bind(m);
40103ea4c501SRahul Kayaith 
4011b56d1ec6SPeter Hawkins   nb::register_exception_translator([](const std::exception_ptr &p,
4012b56d1ec6SPeter Hawkins                                        void *payload) {
40133ea4c501SRahul Kayaith     // We can't define exceptions with custom fields through pybind, so instead
40143ea4c501SRahul Kayaith     // the exception class is defined in python and imported here.
40153ea4c501SRahul Kayaith     try {
40163ea4c501SRahul Kayaith       if (p)
40173ea4c501SRahul Kayaith         std::rethrow_exception(p);
40183ea4c501SRahul Kayaith     } catch (const MLIRError &e) {
4019b56d1ec6SPeter Hawkins       nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
40203ea4c501SRahul Kayaith                            .attr("MLIRError")(e.message, e.errorDiagnostics);
40213ea4c501SRahul Kayaith       PyErr_SetObject(PyExc_Exception, obj.ptr());
40223ea4c501SRahul Kayaith     }
40233ea4c501SRahul Kayaith   });
4024436c6c9cSStella Laurenzo }
4025