1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9b56d1ec6SPeter Hawkins #include <optional> 10b56d1ec6SPeter Hawkins #include <utility> 11436c6c9cSStella Laurenzo 12436c6c9cSStella Laurenzo #include "Globals.h" 13b56d1ec6SPeter Hawkins #include "IRModule.h" 14b56d1ec6SPeter Hawkins #include "NanobindUtils.h" 15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 164acd8457SAlex Zinenko #include "mlir-c/Debug.h" 173ea4c501SRahul Kayaith #include "mlir-c/Diagnostics.h" 183f3d1c90SMike Urbach #include "mlir-c/IR.h" 195c90e1ffSJacques Pienaar #include "mlir-c/Support.h" 205cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 21b56d1ec6SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h" 225cd42747SPeter Hawkins #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 23e67cbbefSJacques Pienaar #include "llvm/ADT/ArrayRef.h" 24436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h" 25436c6c9cSStella Laurenzo 26b56d1ec6SPeter Hawkins namespace nb = nanobind; 27b56d1ec6SPeter Hawkins using namespace nb::literals; 28436c6c9cSStella Laurenzo using namespace mlir; 29436c6c9cSStella Laurenzo using namespace mlir::python; 30436c6c9cSStella Laurenzo 31436c6c9cSStella Laurenzo using llvm::SmallVector; 32436c6c9cSStella Laurenzo using llvm::StringRef; 33436c6c9cSStella Laurenzo using llvm::Twine; 34436c6c9cSStella Laurenzo 35436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 36436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 37436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 38436c6c9cSStella Laurenzo 39436c6c9cSStella Laurenzo static const char kContextParseTypeDocstring[] = 40436c6c9cSStella Laurenzo R"(Parses the assembly form of a type. 41436c6c9cSStella Laurenzo 423ea4c501SRahul Kayaith Returns a Type object or raises an MLIRError if the type cannot be parsed. 43436c6c9cSStella Laurenzo 44436c6c9cSStella Laurenzo See also: https://mlir.llvm.org/docs/LangRef/#type-system 45436c6c9cSStella Laurenzo )"; 46436c6c9cSStella Laurenzo 47e67cbbefSJacques Pienaar static const char kContextGetCallSiteLocationDocstring[] = 48e67cbbefSJacques Pienaar R"(Gets a Location representing a caller and callsite)"; 49e67cbbefSJacques Pienaar 50436c6c9cSStella Laurenzo static const char kContextGetFileLocationDocstring[] = 51436c6c9cSStella Laurenzo R"(Gets a Location representing a file, line and column)"; 52436c6c9cSStella Laurenzo 53a77250fdSJacques Pienaar static const char kContextGetFileRangeDocstring[] = 54a77250fdSJacques Pienaar R"(Gets a Location representing a file, line and column range)"; 55a77250fdSJacques Pienaar 561ab3efacSJacques Pienaar static const char kContextGetFusedLocationDocstring[] = 571ab3efacSJacques Pienaar R"(Gets a Location representing a fused location with optional metadata)"; 581ab3efacSJacques Pienaar 5904d76d36SJacques Pienaar static const char kContextGetNameLocationDocString[] = 6004d76d36SJacques Pienaar R"(Gets a Location representing a named location with optional child location)"; 6104d76d36SJacques Pienaar 62436c6c9cSStella Laurenzo static const char kModuleParseDocstring[] = 63436c6c9cSStella Laurenzo R"(Parses a module's assembly format from a string. 64436c6c9cSStella Laurenzo 653ea4c501SRahul Kayaith Returns a new MlirModule or raises an MLIRError if the parsing fails. 66436c6c9cSStella Laurenzo 67436c6c9cSStella Laurenzo See also: https://mlir.llvm.org/docs/LangRef/ 68436c6c9cSStella Laurenzo )"; 69436c6c9cSStella Laurenzo 70436c6c9cSStella Laurenzo static const char kOperationCreateDocstring[] = 71436c6c9cSStella Laurenzo R"(Creates a new operation. 72436c6c9cSStella Laurenzo 73436c6c9cSStella Laurenzo Args: 74436c6c9cSStella Laurenzo name: Operation name (e.g. "dialect.operation"). 75436c6c9cSStella Laurenzo results: Sequence of Type representing op result types. 76436c6c9cSStella Laurenzo attributes: Dict of str:Attribute. 77436c6c9cSStella Laurenzo successors: List of Block for the operation's successors. 78436c6c9cSStella Laurenzo regions: Number of regions to create. 79436c6c9cSStella Laurenzo location: A Location object (defaults to resolve from context manager). 80436c6c9cSStella Laurenzo ip: An InsertionPoint (defaults to resolve from context manager or set to 81436c6c9cSStella Laurenzo False to disable insertion, even with an insertion point set in the 82436c6c9cSStella Laurenzo context manager). 83f573bc24SJacques Pienaar infer_type: Whether to infer result types. 84436c6c9cSStella Laurenzo Returns: 85436c6c9cSStella Laurenzo A new "detached" Operation object. Detached operations can be added 86436c6c9cSStella Laurenzo to blocks, which causes them to become "attached." 87436c6c9cSStella Laurenzo )"; 88436c6c9cSStella Laurenzo 89436c6c9cSStella Laurenzo static const char kOperationPrintDocstring[] = 90436c6c9cSStella Laurenzo R"(Prints the assembly form of the operation to a file like object. 91436c6c9cSStella Laurenzo 92436c6c9cSStella Laurenzo Args: 93436c6c9cSStella Laurenzo file: The file like object to write to. Defaults to sys.stdout. 94436c6c9cSStella Laurenzo binary: Whether to write bytes (True) or str (False). Defaults to False. 95436c6c9cSStella Laurenzo large_elements_limit: Whether to elide elements attributes above this 96436c6c9cSStella Laurenzo number of elements. Defaults to None (no limit). 97436c6c9cSStella Laurenzo enable_debug_info: Whether to print debug/location information. Defaults 98436c6c9cSStella Laurenzo to False. 99436c6c9cSStella Laurenzo pretty_debug_info: Whether to format debug information for easier reading 100436c6c9cSStella Laurenzo by a human (warning: the result is unparseable). 101436c6c9cSStella Laurenzo print_generic_op_form: Whether to print the generic assembly forms of all 102436c6c9cSStella Laurenzo ops. Defaults to False. 103436c6c9cSStella Laurenzo use_local_Scope: Whether to print in a way that is more optimized for 104436c6c9cSStella Laurenzo multi-threaded access but may not be consistent with how the overall 105436c6c9cSStella Laurenzo module prints. 106ace1d0adSStella Laurenzo assume_verified: By default, if not printing generic form, the verifier 107ace1d0adSStella Laurenzo will be run and if it fails, generic form will be printed with a comment 108ace1d0adSStella Laurenzo about failed verification. While a reasonable default for interactive use, 109ace1d0adSStella Laurenzo for systematic use, it is often better for the caller to verify explicitly 110ace1d0adSStella Laurenzo and report failures in a more robust fashion. Set this to True if doing this 111ace1d0adSStella Laurenzo in order to avoid running a redundant verification. If the IR is actually 112ace1d0adSStella Laurenzo invalid, behavior is undefined. 113abad8455SJonas Rickert skip_regions: Whether to skip printing regions. Defaults to False. 114436c6c9cSStella Laurenzo )"; 115436c6c9cSStella Laurenzo 116204acc5cSJacques Pienaar static const char kOperationPrintStateDocstring[] = 117204acc5cSJacques Pienaar R"(Prints the assembly form of the operation to a file like object. 118204acc5cSJacques Pienaar 119204acc5cSJacques Pienaar Args: 120204acc5cSJacques Pienaar file: The file like object to write to. Defaults to sys.stdout. 121204acc5cSJacques Pienaar binary: Whether to write bytes (True) or str (False). Defaults to False. 122204acc5cSJacques Pienaar state: AsmState capturing the operation numbering and flags. 123204acc5cSJacques Pienaar )"; 124204acc5cSJacques Pienaar 125436c6c9cSStella Laurenzo static const char kOperationGetAsmDocstring[] = 126436c6c9cSStella Laurenzo R"(Gets the assembly form of the operation with all options available. 127436c6c9cSStella Laurenzo 128436c6c9cSStella Laurenzo Args: 129436c6c9cSStella Laurenzo binary: Whether to return a bytes (True) or str (False) object. Defaults to 130436c6c9cSStella Laurenzo False. 131436c6c9cSStella Laurenzo ... others ...: See the print() method for common keyword arguments for 132436c6c9cSStella Laurenzo configuring the printout. 133436c6c9cSStella Laurenzo Returns: 134436c6c9cSStella Laurenzo Either a bytes or str object, depending on the setting of the 'binary' 135436c6c9cSStella Laurenzo argument. 136436c6c9cSStella Laurenzo )"; 137436c6c9cSStella Laurenzo 13889418ddcSMehdi Amini static const char kOperationPrintBytecodeDocstring[] = 13989418ddcSMehdi Amini R"(Write the bytecode form of the operation to a file like object. 14089418ddcSMehdi Amini 14189418ddcSMehdi Amini Args: 14289418ddcSMehdi Amini file: The file like object to write to. 1430610e2f6SJacques Pienaar desired_version: The version of bytecode to emit. 1440610e2f6SJacques Pienaar Returns: 1450610e2f6SJacques Pienaar The bytecode writer status. 14689418ddcSMehdi Amini )"; 14789418ddcSMehdi Amini 148436c6c9cSStella Laurenzo static const char kOperationStrDunderDocstring[] = 149436c6c9cSStella Laurenzo R"(Gets the assembly form of the operation with default options. 150436c6c9cSStella Laurenzo 151436c6c9cSStella Laurenzo If more advanced control over the assembly formatting or I/O options is needed, 152436c6c9cSStella Laurenzo use the dedicated print or get_asm method, which supports keyword arguments to 153436c6c9cSStella Laurenzo customize behavior. 154436c6c9cSStella Laurenzo )"; 155436c6c9cSStella Laurenzo 156436c6c9cSStella Laurenzo static const char kDumpDocstring[] = 157436c6c9cSStella Laurenzo R"(Dumps a debug representation of the object to stderr.)"; 158436c6c9cSStella Laurenzo 159436c6c9cSStella Laurenzo static const char kAppendBlockDocstring[] = 160436c6c9cSStella Laurenzo R"(Appends a new block, with argument types as positional args. 161436c6c9cSStella Laurenzo 162436c6c9cSStella Laurenzo Returns: 163436c6c9cSStella Laurenzo The created block. 164436c6c9cSStella Laurenzo )"; 165436c6c9cSStella Laurenzo 166436c6c9cSStella Laurenzo static const char kValueDunderStrDocstring[] = 167436c6c9cSStella Laurenzo R"(Returns the string form of the value. 168436c6c9cSStella Laurenzo 169436c6c9cSStella Laurenzo If the value is a block argument, this is the assembly form of its type and the 170436c6c9cSStella Laurenzo position in the argument list. If the value is an operation result, this is 171436c6c9cSStella Laurenzo equivalent to printing the operation that produced it. 172436c6c9cSStella Laurenzo )"; 173436c6c9cSStella Laurenzo 17481233c70Smax static const char kGetNameAsOperand[] = 17581233c70Smax R"(Returns the string form of value as an operand (i.e., the ValueID). 17681233c70Smax )"; 17781233c70Smax 1785b303f21Smax static const char kValueReplaceAllUsesWithDocstring[] = 1795b303f21Smax R"(Replace all uses of value with the new value, updating anything in 1805b303f21Smax the IR that uses 'self' to use the other value instead. 1815b303f21Smax )"; 1825b303f21Smax 18321df3251SPerry Gibson static const char kValueReplaceAllUsesExceptDocstring[] = 18421df3251SPerry Gibson R"("Replace all uses of this value with the 'with' value, except for those 18521df3251SPerry Gibson in 'exceptions'. 'exceptions' can be either a single operation or a list of 18621df3251SPerry Gibson operations. 18721df3251SPerry Gibson )"; 18821df3251SPerry Gibson 189436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 190436c6c9cSStella Laurenzo // Utilities. 191436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 192436c6c9cSStella Laurenzo 1934acd8457SAlex Zinenko /// Helper for creating an @classmethod. 194436c6c9cSStella Laurenzo template <class Func, typename... Args> 195b56d1ec6SPeter Hawkins nb::object classmethod(Func f, Args... args) { 196b56d1ec6SPeter Hawkins nb::object cf = nb::cpp_function(f, args...); 197b56d1ec6SPeter Hawkins return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr()))); 198436c6c9cSStella Laurenzo } 199436c6c9cSStella Laurenzo 200b56d1ec6SPeter Hawkins static nb::object 201436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace, 202b56d1ec6SPeter Hawkins nb::object dialectDescriptor) { 203436c6c9cSStella Laurenzo auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 204436c6c9cSStella Laurenzo if (!dialectClass) { 205436c6c9cSStella Laurenzo // Use the base class. 206b56d1ec6SPeter Hawkins return nb::cast(PyDialect(std::move(dialectDescriptor))); 207436c6c9cSStella Laurenzo } 208436c6c9cSStella Laurenzo 209436c6c9cSStella Laurenzo // Create the custom implementation. 210436c6c9cSStella Laurenzo return (*dialectClass)(std::move(dialectDescriptor)); 211436c6c9cSStella Laurenzo } 212436c6c9cSStella Laurenzo 213436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 214436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 215436c6c9cSStella Laurenzo } 216436c6c9cSStella Laurenzo 217f4125e02SPeter Hawkins static MlirStringRef toMlirStringRef(std::string_view s) { 218f4125e02SPeter Hawkins return mlirStringRefCreate(s.data(), s.size()); 219f4125e02SPeter Hawkins } 220f4125e02SPeter Hawkins 221b56d1ec6SPeter Hawkins static MlirStringRef toMlirStringRef(const nb::bytes &s) { 222b56d1ec6SPeter Hawkins return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); 223b56d1ec6SPeter Hawkins } 224b56d1ec6SPeter Hawkins 225514dddbeSRahul Kayaith /// Create a block, using the current location context if no locations are 226514dddbeSRahul Kayaith /// specified. 227b56d1ec6SPeter Hawkins static MlirBlock createBlock(const nb::sequence &pyArgTypes, 228b56d1ec6SPeter Hawkins const std::optional<nb::sequence> &pyArgLocs) { 229514dddbeSRahul Kayaith SmallVector<MlirType> argTypes; 230b56d1ec6SPeter Hawkins argTypes.reserve(nb::len(pyArgTypes)); 231514dddbeSRahul Kayaith for (const auto &pyType : pyArgTypes) 232b56d1ec6SPeter Hawkins argTypes.push_back(nb::cast<PyType &>(pyType)); 233514dddbeSRahul Kayaith 234514dddbeSRahul Kayaith SmallVector<MlirLocation> argLocs; 235514dddbeSRahul Kayaith if (pyArgLocs) { 236b56d1ec6SPeter Hawkins argLocs.reserve(nb::len(*pyArgLocs)); 237514dddbeSRahul Kayaith for (const auto &pyLoc : *pyArgLocs) 238b56d1ec6SPeter Hawkins argLocs.push_back(nb::cast<PyLocation &>(pyLoc)); 239514dddbeSRahul Kayaith } else if (!argTypes.empty()) { 240514dddbeSRahul Kayaith argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); 241514dddbeSRahul Kayaith } 242514dddbeSRahul Kayaith 243514dddbeSRahul Kayaith if (argTypes.size() != argLocs.size()) 244b56d1ec6SPeter Hawkins throw nb::value_error(("Expected " + Twine(argTypes.size()) + 245514dddbeSRahul Kayaith " locations, got: " + Twine(argLocs.size())) 246b56d1ec6SPeter Hawkins .str() 247b56d1ec6SPeter Hawkins .c_str()); 248514dddbeSRahul Kayaith return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 249514dddbeSRahul Kayaith } 250514dddbeSRahul Kayaith 2514acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag. 2524acd8457SAlex Zinenko struct PyGlobalDebugFlag { 253f136c800Svfdev static void set(nb::object &o, bool enable) { 254f136c800Svfdev nb::ft_lock_guard lock(mutex); 255f136c800Svfdev mlirEnableGlobalDebug(enable); 256f136c800Svfdev } 2574acd8457SAlex Zinenko 258f136c800Svfdev static bool get(const nb::object &) { 259f136c800Svfdev nb::ft_lock_guard lock(mutex); 260f136c800Svfdev return mlirIsGlobalDebugEnabled(); 261f136c800Svfdev } 2624acd8457SAlex Zinenko 263b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 2644acd8457SAlex Zinenko // Debug flags. 265b56d1ec6SPeter Hawkins nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 266b56d1ec6SPeter Hawkins .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, 2678f21909cSOleksandr "Alex" Zinenko &PyGlobalDebugFlag::set, "LLVM-wide debug flag") 2688f21909cSOleksandr "Alex" Zinenko .def_static( 2698f21909cSOleksandr "Alex" Zinenko "set_types", 2708f21909cSOleksandr "Alex" Zinenko [](const std::string &type) { 271f136c800Svfdev nb::ft_lock_guard lock(mutex); 2728f21909cSOleksandr "Alex" Zinenko mlirSetGlobalDebugType(type.c_str()); 2738f21909cSOleksandr "Alex" Zinenko }, 2748f21909cSOleksandr "Alex" Zinenko "types"_a, "Sets specific debug types to be produced by LLVM") 2758f21909cSOleksandr "Alex" Zinenko .def_static("set_types", [](const std::vector<std::string> &types) { 2768f21909cSOleksandr "Alex" Zinenko std::vector<const char *> pointers; 2778f21909cSOleksandr "Alex" Zinenko pointers.reserve(types.size()); 2788f21909cSOleksandr "Alex" Zinenko for (const std::string &str : types) 2798f21909cSOleksandr "Alex" Zinenko pointers.push_back(str.c_str()); 280f136c800Svfdev nb::ft_lock_guard lock(mutex); 2818f21909cSOleksandr "Alex" Zinenko mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); 2828f21909cSOleksandr "Alex" Zinenko }); 2834acd8457SAlex Zinenko } 284f136c800Svfdev 285f136c800Svfdev private: 286f136c800Svfdev static nb::ft_mutex mutex; 2874acd8457SAlex Zinenko }; 2884acd8457SAlex Zinenko 289f136c800Svfdev nb::ft_mutex PyGlobalDebugFlag::mutex; 290f136c800Svfdev 291b57acb9aSJacques Pienaar struct PyAttrBuilderMap { 292b57acb9aSJacques Pienaar static bool dunderContains(const std::string &attributeKind) { 293b57acb9aSJacques Pienaar return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); 294b57acb9aSJacques Pienaar } 295a0f5bbcfSvfdev static nb::callable dunderGetItemNamed(const std::string &attributeKind) { 296b57acb9aSJacques Pienaar auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); 297b57acb9aSJacques Pienaar if (!builder) 298b56d1ec6SPeter Hawkins throw nb::key_error(attributeKind.c_str()); 299b57acb9aSJacques Pienaar return *builder; 300b57acb9aSJacques Pienaar } 301a0f5bbcfSvfdev static void dunderSetItemNamed(const std::string &attributeKind, 302b56d1ec6SPeter Hawkins nb::callable func, bool replace) { 30392233062Smax PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), 30492233062Smax replace); 305b57acb9aSJacques Pienaar } 306b57acb9aSJacques Pienaar 307b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 308b56d1ec6SPeter Hawkins nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") 309b57acb9aSJacques Pienaar .def_static("contains", &PyAttrBuilderMap::dunderContains) 310a0f5bbcfSvfdev .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) 311a0f5bbcfSvfdev .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, 31292233062Smax "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, 31392233062Smax "Register an attribute builder for building MLIR " 31492233062Smax "attributes from python values."); 315b57acb9aSJacques Pienaar } 316b57acb9aSJacques Pienaar }; 317b57acb9aSJacques Pienaar 318436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 319c83318e3SAdam Paszke // PyBlock 320c83318e3SAdam Paszke //------------------------------------------------------------------------------ 321c83318e3SAdam Paszke 322b56d1ec6SPeter Hawkins nb::object PyBlock::getCapsule() { 323b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonBlockToCapsule(get())); 324c83318e3SAdam Paszke } 325c83318e3SAdam Paszke 326c83318e3SAdam Paszke //------------------------------------------------------------------------------ 327436c6c9cSStella Laurenzo // Collections. 328436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 329436c6c9cSStella Laurenzo 330436c6c9cSStella Laurenzo namespace { 331436c6c9cSStella Laurenzo 332436c6c9cSStella Laurenzo class PyRegionIterator { 333436c6c9cSStella Laurenzo public: 334436c6c9cSStella Laurenzo PyRegionIterator(PyOperationRef operation) 335436c6c9cSStella Laurenzo : operation(std::move(operation)) {} 336436c6c9cSStella Laurenzo 337436c6c9cSStella Laurenzo PyRegionIterator &dunderIter() { return *this; } 338436c6c9cSStella Laurenzo 339436c6c9cSStella Laurenzo PyRegion dunderNext() { 340436c6c9cSStella Laurenzo operation->checkValid(); 341436c6c9cSStella Laurenzo if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 342b56d1ec6SPeter Hawkins throw nb::stop_iteration(); 343436c6c9cSStella Laurenzo } 344436c6c9cSStella Laurenzo MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 345436c6c9cSStella Laurenzo return PyRegion(operation, region); 346436c6c9cSStella Laurenzo } 347436c6c9cSStella Laurenzo 348b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 349b56d1ec6SPeter Hawkins nb::class_<PyRegionIterator>(m, "RegionIterator") 350436c6c9cSStella Laurenzo .def("__iter__", &PyRegionIterator::dunderIter) 351436c6c9cSStella Laurenzo .def("__next__", &PyRegionIterator::dunderNext); 352436c6c9cSStella Laurenzo } 353436c6c9cSStella Laurenzo 354436c6c9cSStella Laurenzo private: 355436c6c9cSStella Laurenzo PyOperationRef operation; 356436c6c9cSStella Laurenzo int nextIndex = 0; 357436c6c9cSStella Laurenzo }; 358436c6c9cSStella Laurenzo 359436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented 360436c6c9cSStella Laurenzo /// with a sequence-like container. 361436c6c9cSStella Laurenzo class PyRegionList { 362436c6c9cSStella Laurenzo public: 363436c6c9cSStella Laurenzo PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 364436c6c9cSStella Laurenzo 365d0d26ee7SRahul Kayaith PyRegionIterator dunderIter() { 366d0d26ee7SRahul Kayaith operation->checkValid(); 367d0d26ee7SRahul Kayaith return PyRegionIterator(operation); 368d0d26ee7SRahul Kayaith } 369d0d26ee7SRahul Kayaith 370436c6c9cSStella Laurenzo intptr_t dunderLen() { 371436c6c9cSStella Laurenzo operation->checkValid(); 372436c6c9cSStella Laurenzo return mlirOperationGetNumRegions(operation->get()); 373436c6c9cSStella Laurenzo } 374436c6c9cSStella Laurenzo 375436c6c9cSStella Laurenzo PyRegion dunderGetItem(intptr_t index) { 376436c6c9cSStella Laurenzo // dunderLen checks validity. 377436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 378b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds region"); 379436c6c9cSStella Laurenzo } 380436c6c9cSStella Laurenzo MlirRegion region = mlirOperationGetRegion(operation->get(), index); 381436c6c9cSStella Laurenzo return PyRegion(operation, region); 382436c6c9cSStella Laurenzo } 383436c6c9cSStella Laurenzo 384b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 385b56d1ec6SPeter Hawkins nb::class_<PyRegionList>(m, "RegionSequence") 386436c6c9cSStella Laurenzo .def("__len__", &PyRegionList::dunderLen) 387d0d26ee7SRahul Kayaith .def("__iter__", &PyRegionList::dunderIter) 388436c6c9cSStella Laurenzo .def("__getitem__", &PyRegionList::dunderGetItem); 389436c6c9cSStella Laurenzo } 390436c6c9cSStella Laurenzo 391436c6c9cSStella Laurenzo private: 392436c6c9cSStella Laurenzo PyOperationRef operation; 393436c6c9cSStella Laurenzo }; 394436c6c9cSStella Laurenzo 395436c6c9cSStella Laurenzo class PyBlockIterator { 396436c6c9cSStella Laurenzo public: 397436c6c9cSStella Laurenzo PyBlockIterator(PyOperationRef operation, MlirBlock next) 398436c6c9cSStella Laurenzo : operation(std::move(operation)), next(next) {} 399436c6c9cSStella Laurenzo 400436c6c9cSStella Laurenzo PyBlockIterator &dunderIter() { return *this; } 401436c6c9cSStella Laurenzo 402436c6c9cSStella Laurenzo PyBlock dunderNext() { 403436c6c9cSStella Laurenzo operation->checkValid(); 404436c6c9cSStella Laurenzo if (mlirBlockIsNull(next)) { 405b56d1ec6SPeter Hawkins throw nb::stop_iteration(); 406436c6c9cSStella Laurenzo } 407436c6c9cSStella Laurenzo 408436c6c9cSStella Laurenzo PyBlock returnBlock(operation, next); 409436c6c9cSStella Laurenzo next = mlirBlockGetNextInRegion(next); 410436c6c9cSStella Laurenzo return returnBlock; 411436c6c9cSStella Laurenzo } 412436c6c9cSStella Laurenzo 413b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 414b56d1ec6SPeter Hawkins nb::class_<PyBlockIterator>(m, "BlockIterator") 415436c6c9cSStella Laurenzo .def("__iter__", &PyBlockIterator::dunderIter) 416436c6c9cSStella Laurenzo .def("__next__", &PyBlockIterator::dunderNext); 417436c6c9cSStella Laurenzo } 418436c6c9cSStella Laurenzo 419436c6c9cSStella Laurenzo private: 420436c6c9cSStella Laurenzo PyOperationRef operation; 421436c6c9cSStella Laurenzo MlirBlock next; 422436c6c9cSStella Laurenzo }; 423436c6c9cSStella Laurenzo 424436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 425436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize 426436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region. 427436c6c9cSStella Laurenzo class PyBlockList { 428436c6c9cSStella Laurenzo public: 429436c6c9cSStella Laurenzo PyBlockList(PyOperationRef operation, MlirRegion region) 430436c6c9cSStella Laurenzo : operation(std::move(operation)), region(region) {} 431436c6c9cSStella Laurenzo 432436c6c9cSStella Laurenzo PyBlockIterator dunderIter() { 433436c6c9cSStella Laurenzo operation->checkValid(); 434436c6c9cSStella Laurenzo return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 435436c6c9cSStella Laurenzo } 436436c6c9cSStella Laurenzo 437436c6c9cSStella Laurenzo intptr_t dunderLen() { 438436c6c9cSStella Laurenzo operation->checkValid(); 439436c6c9cSStella Laurenzo intptr_t count = 0; 440436c6c9cSStella Laurenzo MlirBlock block = mlirRegionGetFirstBlock(region); 441436c6c9cSStella Laurenzo while (!mlirBlockIsNull(block)) { 442436c6c9cSStella Laurenzo count += 1; 443436c6c9cSStella Laurenzo block = mlirBlockGetNextInRegion(block); 444436c6c9cSStella Laurenzo } 445436c6c9cSStella Laurenzo return count; 446436c6c9cSStella Laurenzo } 447436c6c9cSStella Laurenzo 448436c6c9cSStella Laurenzo PyBlock dunderGetItem(intptr_t index) { 449436c6c9cSStella Laurenzo operation->checkValid(); 450436c6c9cSStella Laurenzo if (index < 0) { 451b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds block"); 452436c6c9cSStella Laurenzo } 453436c6c9cSStella Laurenzo MlirBlock block = mlirRegionGetFirstBlock(region); 454436c6c9cSStella Laurenzo while (!mlirBlockIsNull(block)) { 455436c6c9cSStella Laurenzo if (index == 0) { 456436c6c9cSStella Laurenzo return PyBlock(operation, block); 457436c6c9cSStella Laurenzo } 458436c6c9cSStella Laurenzo block = mlirBlockGetNextInRegion(block); 459436c6c9cSStella Laurenzo index -= 1; 460436c6c9cSStella Laurenzo } 461b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds block"); 462436c6c9cSStella Laurenzo } 463436c6c9cSStella Laurenzo 464b56d1ec6SPeter Hawkins PyBlock appendBlock(const nb::args &pyArgTypes, 465b56d1ec6SPeter Hawkins const std::optional<nb::sequence> &pyArgLocs) { 466436c6c9cSStella Laurenzo operation->checkValid(); 467b56d1ec6SPeter Hawkins MlirBlock block = 468b56d1ec6SPeter Hawkins createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); 469436c6c9cSStella Laurenzo mlirRegionAppendOwnedBlock(region, block); 470436c6c9cSStella Laurenzo return PyBlock(operation, block); 471436c6c9cSStella Laurenzo } 472436c6c9cSStella Laurenzo 473b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 474b56d1ec6SPeter Hawkins nb::class_<PyBlockList>(m, "BlockList") 475436c6c9cSStella Laurenzo .def("__getitem__", &PyBlockList::dunderGetItem) 476436c6c9cSStella Laurenzo .def("__iter__", &PyBlockList::dunderIter) 477436c6c9cSStella Laurenzo .def("__len__", &PyBlockList::dunderLen) 478514dddbeSRahul Kayaith .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, 479b56d1ec6SPeter Hawkins nb::arg("args"), nb::kw_only(), 480b56d1ec6SPeter Hawkins nb::arg("arg_locs") = std::nullopt); 481436c6c9cSStella Laurenzo } 482436c6c9cSStella Laurenzo 483436c6c9cSStella Laurenzo private: 484436c6c9cSStella Laurenzo PyOperationRef operation; 485436c6c9cSStella Laurenzo MlirRegion region; 486436c6c9cSStella Laurenzo }; 487436c6c9cSStella Laurenzo 488436c6c9cSStella Laurenzo class PyOperationIterator { 489436c6c9cSStella Laurenzo public: 490436c6c9cSStella Laurenzo PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 491436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), next(next) {} 492436c6c9cSStella Laurenzo 493436c6c9cSStella Laurenzo PyOperationIterator &dunderIter() { return *this; } 494436c6c9cSStella Laurenzo 495b56d1ec6SPeter Hawkins nb::object dunderNext() { 496436c6c9cSStella Laurenzo parentOperation->checkValid(); 497436c6c9cSStella Laurenzo if (mlirOperationIsNull(next)) { 498b56d1ec6SPeter Hawkins throw nb::stop_iteration(); 499436c6c9cSStella Laurenzo } 500436c6c9cSStella Laurenzo 501436c6c9cSStella Laurenzo PyOperationRef returnOperation = 502436c6c9cSStella Laurenzo PyOperation::forOperation(parentOperation->getContext(), next); 503436c6c9cSStella Laurenzo next = mlirOperationGetNextInBlock(next); 504436c6c9cSStella Laurenzo return returnOperation->createOpView(); 505436c6c9cSStella Laurenzo } 506436c6c9cSStella Laurenzo 507b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 508b56d1ec6SPeter Hawkins nb::class_<PyOperationIterator>(m, "OperationIterator") 509436c6c9cSStella Laurenzo .def("__iter__", &PyOperationIterator::dunderIter) 510436c6c9cSStella Laurenzo .def("__next__", &PyOperationIterator::dunderNext); 511436c6c9cSStella Laurenzo } 512436c6c9cSStella Laurenzo 513436c6c9cSStella Laurenzo private: 514436c6c9cSStella Laurenzo PyOperationRef parentOperation; 515436c6c9cSStella Laurenzo MlirOperation next; 516436c6c9cSStella Laurenzo }; 517436c6c9cSStella Laurenzo 518436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In 519436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but 520436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned 521436c6c9cSStella Laurenzo /// by a block. 522436c6c9cSStella Laurenzo class PyOperationList { 523436c6c9cSStella Laurenzo public: 524436c6c9cSStella Laurenzo PyOperationList(PyOperationRef parentOperation, MlirBlock block) 525436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), block(block) {} 526436c6c9cSStella Laurenzo 527436c6c9cSStella Laurenzo PyOperationIterator dunderIter() { 528436c6c9cSStella Laurenzo parentOperation->checkValid(); 529436c6c9cSStella Laurenzo return PyOperationIterator(parentOperation, 530436c6c9cSStella Laurenzo mlirBlockGetFirstOperation(block)); 531436c6c9cSStella Laurenzo } 532436c6c9cSStella Laurenzo 533436c6c9cSStella Laurenzo intptr_t dunderLen() { 534436c6c9cSStella Laurenzo parentOperation->checkValid(); 535436c6c9cSStella Laurenzo intptr_t count = 0; 536436c6c9cSStella Laurenzo MlirOperation childOp = mlirBlockGetFirstOperation(block); 537436c6c9cSStella Laurenzo while (!mlirOperationIsNull(childOp)) { 538436c6c9cSStella Laurenzo count += 1; 539436c6c9cSStella Laurenzo childOp = mlirOperationGetNextInBlock(childOp); 540436c6c9cSStella Laurenzo } 541436c6c9cSStella Laurenzo return count; 542436c6c9cSStella Laurenzo } 543436c6c9cSStella Laurenzo 544b56d1ec6SPeter Hawkins nb::object dunderGetItem(intptr_t index) { 545436c6c9cSStella Laurenzo parentOperation->checkValid(); 546436c6c9cSStella Laurenzo if (index < 0) { 547b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds operation"); 548436c6c9cSStella Laurenzo } 549436c6c9cSStella Laurenzo MlirOperation childOp = mlirBlockGetFirstOperation(block); 550436c6c9cSStella Laurenzo while (!mlirOperationIsNull(childOp)) { 551436c6c9cSStella Laurenzo if (index == 0) { 552436c6c9cSStella Laurenzo return PyOperation::forOperation(parentOperation->getContext(), childOp) 553436c6c9cSStella Laurenzo ->createOpView(); 554436c6c9cSStella Laurenzo } 555436c6c9cSStella Laurenzo childOp = mlirOperationGetNextInBlock(childOp); 556436c6c9cSStella Laurenzo index -= 1; 557436c6c9cSStella Laurenzo } 558b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds operation"); 559436c6c9cSStella Laurenzo } 560436c6c9cSStella Laurenzo 561b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 562b56d1ec6SPeter Hawkins nb::class_<PyOperationList>(m, "OperationList") 563436c6c9cSStella Laurenzo .def("__getitem__", &PyOperationList::dunderGetItem) 564436c6c9cSStella Laurenzo .def("__iter__", &PyOperationList::dunderIter) 565436c6c9cSStella Laurenzo .def("__len__", &PyOperationList::dunderLen); 566436c6c9cSStella Laurenzo } 567436c6c9cSStella Laurenzo 568436c6c9cSStella Laurenzo private: 569436c6c9cSStella Laurenzo PyOperationRef parentOperation; 570436c6c9cSStella Laurenzo MlirBlock block; 571436c6c9cSStella Laurenzo }; 572436c6c9cSStella Laurenzo 573afb2ed80SMike Urbach class PyOpOperand { 574afb2ed80SMike Urbach public: 575afb2ed80SMike Urbach PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} 576afb2ed80SMike Urbach 577b56d1ec6SPeter Hawkins nb::object getOwner() { 578afb2ed80SMike Urbach MlirOperation owner = mlirOpOperandGetOwner(opOperand); 579afb2ed80SMike Urbach PyMlirContextRef context = 580afb2ed80SMike Urbach PyMlirContext::forContext(mlirOperationGetContext(owner)); 581afb2ed80SMike Urbach return PyOperation::forOperation(context, owner)->createOpView(); 582afb2ed80SMike Urbach } 583afb2ed80SMike Urbach 584afb2ed80SMike Urbach size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } 585afb2ed80SMike Urbach 586b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 587b56d1ec6SPeter Hawkins nb::class_<PyOpOperand>(m, "OpOperand") 588b56d1ec6SPeter Hawkins .def_prop_ro("owner", &PyOpOperand::getOwner) 589b56d1ec6SPeter Hawkins .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); 590afb2ed80SMike Urbach } 591afb2ed80SMike Urbach 592afb2ed80SMike Urbach private: 593afb2ed80SMike Urbach MlirOpOperand opOperand; 594afb2ed80SMike Urbach }; 595afb2ed80SMike Urbach 596afb2ed80SMike Urbach class PyOpOperandIterator { 597afb2ed80SMike Urbach public: 598afb2ed80SMike Urbach PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} 599afb2ed80SMike Urbach 600afb2ed80SMike Urbach PyOpOperandIterator &dunderIter() { return *this; } 601afb2ed80SMike Urbach 602afb2ed80SMike Urbach PyOpOperand dunderNext() { 603afb2ed80SMike Urbach if (mlirOpOperandIsNull(opOperand)) 604b56d1ec6SPeter Hawkins throw nb::stop_iteration(); 605afb2ed80SMike Urbach 606afb2ed80SMike Urbach PyOpOperand returnOpOperand(opOperand); 607afb2ed80SMike Urbach opOperand = mlirOpOperandGetNextUse(opOperand); 608afb2ed80SMike Urbach return returnOpOperand; 609afb2ed80SMike Urbach } 610afb2ed80SMike Urbach 611b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 612b56d1ec6SPeter Hawkins nb::class_<PyOpOperandIterator>(m, "OpOperandIterator") 613afb2ed80SMike Urbach .def("__iter__", &PyOpOperandIterator::dunderIter) 614afb2ed80SMike Urbach .def("__next__", &PyOpOperandIterator::dunderNext); 615afb2ed80SMike Urbach } 616afb2ed80SMike Urbach 617afb2ed80SMike Urbach private: 618afb2ed80SMike Urbach MlirOpOperand opOperand; 619afb2ed80SMike Urbach }; 620afb2ed80SMike Urbach 621436c6c9cSStella Laurenzo } // namespace 622436c6c9cSStella Laurenzo 623436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 624436c6c9cSStella Laurenzo // PyMlirContext 625436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 626436c6c9cSStella Laurenzo 627436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 628b56d1ec6SPeter Hawkins nb::gil_scoped_acquire acquire; 629f136c800Svfdev nb::ft_lock_guard lock(live_contexts_mutex); 630436c6c9cSStella Laurenzo auto &liveContexts = getLiveContexts(); 631436c6c9cSStella Laurenzo liveContexts[context.ptr] = this; 632436c6c9cSStella Laurenzo } 633436c6c9cSStella Laurenzo 634436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() { 635436c6c9cSStella Laurenzo // Note that the only public way to construct an instance is via the 636436c6c9cSStella Laurenzo // forContext method, which always puts the associated handle into 637436c6c9cSStella Laurenzo // liveContexts. 638b56d1ec6SPeter Hawkins nb::gil_scoped_acquire acquire; 639f136c800Svfdev { 640f136c800Svfdev nb::ft_lock_guard lock(live_contexts_mutex); 641436c6c9cSStella Laurenzo getLiveContexts().erase(context.ptr); 642f136c800Svfdev } 643436c6c9cSStella Laurenzo mlirContextDestroy(context); 644436c6c9cSStella Laurenzo } 645436c6c9cSStella Laurenzo 646b56d1ec6SPeter Hawkins nb::object PyMlirContext::getCapsule() { 647b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonContextToCapsule(get())); 648436c6c9cSStella Laurenzo } 649436c6c9cSStella Laurenzo 650b56d1ec6SPeter Hawkins nb::object PyMlirContext::createFromCapsule(nb::object capsule) { 651436c6c9cSStella Laurenzo MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 652436c6c9cSStella Laurenzo if (mlirContextIsNull(rawContext)) 653b56d1ec6SPeter Hawkins throw nb::python_error(); 65478bd1246SAlex Zinenko return forContext(rawContext).releaseObject(); 655436c6c9cSStella Laurenzo } 656436c6c9cSStella Laurenzo 657436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 658b56d1ec6SPeter Hawkins nb::gil_scoped_acquire acquire; 659f136c800Svfdev nb::ft_lock_guard lock(live_contexts_mutex); 660436c6c9cSStella Laurenzo auto &liveContexts = getLiveContexts(); 661436c6c9cSStella Laurenzo auto it = liveContexts.find(context.ptr); 662436c6c9cSStella Laurenzo if (it == liveContexts.end()) { 66378bd1246SAlex Zinenko // Create. 66478bd1246SAlex Zinenko PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 665b56d1ec6SPeter Hawkins nb::object pyRef = nb::cast(unownedContextWrapper); 666b56d1ec6SPeter Hawkins assert(pyRef && "cast to nb::object failed"); 66778bd1246SAlex Zinenko liveContexts[context.ptr] = unownedContextWrapper; 66878bd1246SAlex Zinenko return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 669436c6c9cSStella Laurenzo } 670436c6c9cSStella Laurenzo // Use existing. 671b56d1ec6SPeter Hawkins nb::object pyRef = nb::cast(it->second); 672436c6c9cSStella Laurenzo return PyMlirContextRef(it->second, std::move(pyRef)); 673436c6c9cSStella Laurenzo } 674436c6c9cSStella Laurenzo 675f136c800Svfdev nb::ft_mutex PyMlirContext::live_contexts_mutex; 676f136c800Svfdev 677436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 678436c6c9cSStella Laurenzo static LiveContextMap liveContexts; 679436c6c9cSStella Laurenzo return liveContexts; 680436c6c9cSStella Laurenzo } 681436c6c9cSStella Laurenzo 682f136c800Svfdev size_t PyMlirContext::getLiveCount() { 683f136c800Svfdev nb::ft_lock_guard lock(live_contexts_mutex); 684f136c800Svfdev return getLiveContexts().size(); 685f136c800Svfdev } 686436c6c9cSStella Laurenzo 687e2c49a45SPeter Hawkins size_t PyMlirContext::getLiveOperationCount() { 688e2c49a45SPeter Hawkins nb::ft_lock_guard lock(liveOperationsMutex); 689e2c49a45SPeter Hawkins return liveOperations.size(); 690e2c49a45SPeter Hawkins } 691436c6c9cSStella Laurenzo 692d1fdb416SJohn Demme std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() { 693d1fdb416SJohn Demme std::vector<PyOperation *> liveObjects; 694e2c49a45SPeter Hawkins nb::ft_lock_guard lock(liveOperationsMutex); 695d1fdb416SJohn Demme for (auto &entry : liveOperations) 696d1fdb416SJohn Demme liveObjects.push_back(entry.second.second); 697d1fdb416SJohn Demme return liveObjects; 698d1fdb416SJohn Demme } 699d1fdb416SJohn Demme 7006b0bed7eSJohn Demme size_t PyMlirContext::clearLiveOperations() { 701e2c49a45SPeter Hawkins 702e2c49a45SPeter Hawkins LiveOperationMap operations; 703e2c49a45SPeter Hawkins { 704e2c49a45SPeter Hawkins nb::ft_lock_guard lock(liveOperationsMutex); 705e2c49a45SPeter Hawkins std::swap(operations, liveOperations); 706e2c49a45SPeter Hawkins } 707e2c49a45SPeter Hawkins for (auto &op : operations) 7086b0bed7eSJohn Demme op.second.second->setInvalid(); 709e2c49a45SPeter Hawkins size_t numInvalidated = operations.size(); 7106b0bed7eSJohn Demme return numInvalidated; 7116b0bed7eSJohn Demme } 7126b0bed7eSJohn Demme 713fa19ef7aSIngo Müller void PyMlirContext::clearOperation(MlirOperation op) { 714e2c49a45SPeter Hawkins PyOperation *py_op; 715e2c49a45SPeter Hawkins { 716e2c49a45SPeter Hawkins nb::ft_lock_guard lock(liveOperationsMutex); 717fa19ef7aSIngo Müller auto it = liveOperations.find(op.ptr); 718e2c49a45SPeter Hawkins if (it == liveOperations.end()) { 719e2c49a45SPeter Hawkins return; 720e2c49a45SPeter Hawkins } 721e2c49a45SPeter Hawkins py_op = it->second.second; 722fa19ef7aSIngo Müller liveOperations.erase(it); 723fa19ef7aSIngo Müller } 724e2c49a45SPeter Hawkins py_op->setInvalid(); 725fa19ef7aSIngo Müller } 726fa19ef7aSIngo Müller 727fa19ef7aSIngo Müller void PyMlirContext::clearOperationsInside(PyOperationBase &op) { 728fa19ef7aSIngo Müller typedef struct { 729fa19ef7aSIngo Müller PyOperation &rootOp; 730fa19ef7aSIngo Müller bool rootSeen; 731fa19ef7aSIngo Müller } callBackData; 732fa19ef7aSIngo Müller callBackData data{op.getOperation(), false}; 733fa19ef7aSIngo Müller // Mark all ops below the op that the passmanager will be rooted 734fa19ef7aSIngo Müller // at (but not op itself - note the preorder) as invalid. 735fa19ef7aSIngo Müller MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, 736fa19ef7aSIngo Müller void *userData) { 737fa19ef7aSIngo Müller callBackData *data = static_cast<callBackData *>(userData); 738fa19ef7aSIngo Müller if (LLVM_LIKELY(data->rootSeen)) 739fa19ef7aSIngo Müller data->rootOp.getOperation().getContext()->clearOperation(op); 740fa19ef7aSIngo Müller else 741fa19ef7aSIngo Müller data->rootSeen = true; 74247148832SHideto Ueno return MlirWalkResult::MlirWalkResultAdvance; 743fa19ef7aSIngo Müller }; 744fa19ef7aSIngo Müller mlirOperationWalk(op.getOperation(), invalidatingCallback, 745fa19ef7aSIngo Müller static_cast<void *>(&data), MlirWalkPreOrder); 746bdc3e6cbSMaksim Levental } 74791f11611SOleksandr "Alex" Zinenko void PyMlirContext::clearOperationsInside(MlirOperation op) { 74891f11611SOleksandr "Alex" Zinenko PyOperationRef opRef = PyOperation::forOperation(getRef(), op); 74991f11611SOleksandr "Alex" Zinenko clearOperationsInside(opRef->getOperation()); 75091f11611SOleksandr "Alex" Zinenko } 751bdc3e6cbSMaksim Levental 75267897d77SOleksandr "Alex" Zinenko void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { 75367897d77SOleksandr "Alex" Zinenko MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, 75467897d77SOleksandr "Alex" Zinenko void *userData) { 75567897d77SOleksandr "Alex" Zinenko PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData); 75667897d77SOleksandr "Alex" Zinenko contextRef->clearOperation(op); 75767897d77SOleksandr "Alex" Zinenko return MlirWalkResult::MlirWalkResultAdvance; 75867897d77SOleksandr "Alex" Zinenko }; 75967897d77SOleksandr "Alex" Zinenko mlirOperationWalk(op.getOperation(), invalidatingCallback, 76067897d77SOleksandr "Alex" Zinenko &op.getOperation().getContext(), MlirWalkPreOrder); 76167897d77SOleksandr "Alex" Zinenko } 76267897d77SOleksandr "Alex" Zinenko 763436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 764436c6c9cSStella Laurenzo 765b56d1ec6SPeter Hawkins nb::object PyMlirContext::contextEnter(nb::object context) { 766b56d1ec6SPeter Hawkins return PyThreadContextEntry::pushContext(context); 767436c6c9cSStella Laurenzo } 768436c6c9cSStella Laurenzo 769b56d1ec6SPeter Hawkins void PyMlirContext::contextExit(const nb::object &excType, 770b56d1ec6SPeter Hawkins const nb::object &excVal, 771b56d1ec6SPeter Hawkins const nb::object &excTb) { 772436c6c9cSStella Laurenzo PyThreadContextEntry::popContext(*this); 773436c6c9cSStella Laurenzo } 774436c6c9cSStella Laurenzo 775b56d1ec6SPeter Hawkins nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { 7767ee25bc5SStella Laurenzo // Note that ownership is transferred to the delete callback below by way of 7777ee25bc5SStella Laurenzo // an explicit inc_ref (borrow). 7787ee25bc5SStella Laurenzo PyDiagnosticHandler *pyHandler = 7797ee25bc5SStella Laurenzo new PyDiagnosticHandler(get(), std::move(callback)); 780b56d1ec6SPeter Hawkins nb::object pyHandlerObject = 781b56d1ec6SPeter Hawkins nb::cast(pyHandler, nb::rv_policy::take_ownership); 7827ee25bc5SStella Laurenzo pyHandlerObject.inc_ref(); 7837ee25bc5SStella Laurenzo 7847ee25bc5SStella Laurenzo // In these C callbacks, the userData is a PyDiagnosticHandler* that is 7857ee25bc5SStella Laurenzo // guaranteed to be known to pybind. 7867ee25bc5SStella Laurenzo auto handlerCallback = 7877ee25bc5SStella Laurenzo +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 7887ee25bc5SStella Laurenzo PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 789b56d1ec6SPeter Hawkins nb::object pyDiagnosticObject = 790b56d1ec6SPeter Hawkins nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); 7917ee25bc5SStella Laurenzo 7927ee25bc5SStella Laurenzo auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 7937ee25bc5SStella Laurenzo bool result = false; 7947ee25bc5SStella Laurenzo { 7957ee25bc5SStella Laurenzo // Since this can be called from arbitrary C++ contexts, always get the 7967ee25bc5SStella Laurenzo // gil. 797b56d1ec6SPeter Hawkins nb::gil_scoped_acquire gil; 7987ee25bc5SStella Laurenzo try { 799b56d1ec6SPeter Hawkins result = nb::cast<bool>(pyHandler->callback(pyDiagnostic)); 8007ee25bc5SStella Laurenzo } catch (std::exception &e) { 8017ee25bc5SStella Laurenzo fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 8027ee25bc5SStella Laurenzo e.what()); 8037ee25bc5SStella Laurenzo pyHandler->hadError = true; 8047ee25bc5SStella Laurenzo } 8057ee25bc5SStella Laurenzo } 8067ee25bc5SStella Laurenzo 8077ee25bc5SStella Laurenzo pyDiagnostic->invalidate(); 8087ee25bc5SStella Laurenzo return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 8097ee25bc5SStella Laurenzo }; 8107ee25bc5SStella Laurenzo auto deleteCallback = +[](void *userData) { 8117ee25bc5SStella Laurenzo auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 8127ee25bc5SStella Laurenzo assert(pyHandler->registeredID && "handler is not registered"); 8137ee25bc5SStella Laurenzo pyHandler->registeredID.reset(); 8147ee25bc5SStella Laurenzo 8157ee25bc5SStella Laurenzo // Decrement reference, balancing the inc_ref() above. 816b56d1ec6SPeter Hawkins nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); 8177ee25bc5SStella Laurenzo pyHandlerObject.dec_ref(); 8187ee25bc5SStella Laurenzo }; 8197ee25bc5SStella Laurenzo 8207ee25bc5SStella Laurenzo pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 8217ee25bc5SStella Laurenzo get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 8227ee25bc5SStella Laurenzo return pyHandlerObject; 8237ee25bc5SStella Laurenzo } 8247ee25bc5SStella Laurenzo 8253ea4c501SRahul Kayaith MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, 8263ea4c501SRahul Kayaith void *userData) { 8273ea4c501SRahul Kayaith auto *self = static_cast<ErrorCapture *>(userData); 8283ea4c501SRahul Kayaith // Check if the context requested we emit errors instead of capturing them. 8293ea4c501SRahul Kayaith if (self->ctx->emitErrorDiagnostics) 8303ea4c501SRahul Kayaith return mlirLogicalResultFailure(); 8313ea4c501SRahul Kayaith 8323ea4c501SRahul Kayaith if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) 8333ea4c501SRahul Kayaith return mlirLogicalResultFailure(); 8343ea4c501SRahul Kayaith 8353ea4c501SRahul Kayaith self->errors.emplace_back(PyDiagnostic(diag).getInfo()); 8363ea4c501SRahul Kayaith return mlirLogicalResultSuccess(); 8373ea4c501SRahul Kayaith } 8383ea4c501SRahul Kayaith 839436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() { 840436c6c9cSStella Laurenzo PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 841436c6c9cSStella Laurenzo if (!context) { 8424811270bSmax throw std::runtime_error( 843436c6c9cSStella Laurenzo "An MLIR function requires a Context but none was provided in the call " 844436c6c9cSStella Laurenzo "or from the surrounding environment. Either pass to the function with " 845436c6c9cSStella Laurenzo "a 'context=' argument or establish a default using 'with Context():'"); 846436c6c9cSStella Laurenzo } 847436c6c9cSStella Laurenzo return *context; 848436c6c9cSStella Laurenzo } 849436c6c9cSStella Laurenzo 850436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 851436c6c9cSStella Laurenzo // PyThreadContextEntry management 852436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 853436c6c9cSStella Laurenzo 854436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 855436c6c9cSStella Laurenzo static thread_local std::vector<PyThreadContextEntry> stack; 856436c6c9cSStella Laurenzo return stack; 857436c6c9cSStella Laurenzo } 858436c6c9cSStella Laurenzo 859436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 860436c6c9cSStella Laurenzo auto &stack = getStack(); 861436c6c9cSStella Laurenzo if (stack.empty()) 862436c6c9cSStella Laurenzo return nullptr; 863436c6c9cSStella Laurenzo return &stack.back(); 864436c6c9cSStella Laurenzo } 865436c6c9cSStella Laurenzo 866b56d1ec6SPeter Hawkins void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, 867b56d1ec6SPeter Hawkins nb::object insertionPoint, 868b56d1ec6SPeter Hawkins nb::object location) { 869436c6c9cSStella Laurenzo auto &stack = getStack(); 870436c6c9cSStella Laurenzo stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 871436c6c9cSStella Laurenzo std::move(location)); 872436c6c9cSStella Laurenzo // If the new stack has more than one entry and the context of the new top 873436c6c9cSStella Laurenzo // entry matches the previous, copy the insertionPoint and location from the 874436c6c9cSStella Laurenzo // previous entry if missing from the new top entry. 875436c6c9cSStella Laurenzo if (stack.size() > 1) { 876436c6c9cSStella Laurenzo auto &prev = *(stack.rbegin() + 1); 877436c6c9cSStella Laurenzo auto ¤t = stack.back(); 878436c6c9cSStella Laurenzo if (current.context.is(prev.context)) { 879436c6c9cSStella Laurenzo // Default non-context objects from the previous entry. 880436c6c9cSStella Laurenzo if (!current.insertionPoint) 881436c6c9cSStella Laurenzo current.insertionPoint = prev.insertionPoint; 882436c6c9cSStella Laurenzo if (!current.location) 883436c6c9cSStella Laurenzo current.location = prev.location; 884436c6c9cSStella Laurenzo } 885436c6c9cSStella Laurenzo } 886436c6c9cSStella Laurenzo } 887436c6c9cSStella Laurenzo 888436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() { 889436c6c9cSStella Laurenzo if (!context) 890436c6c9cSStella Laurenzo return nullptr; 891b56d1ec6SPeter Hawkins return nb::cast<PyMlirContext *>(context); 892436c6c9cSStella Laurenzo } 893436c6c9cSStella Laurenzo 894436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 895436c6c9cSStella Laurenzo if (!insertionPoint) 896436c6c9cSStella Laurenzo return nullptr; 897b56d1ec6SPeter Hawkins return nb::cast<PyInsertionPoint *>(insertionPoint); 898436c6c9cSStella Laurenzo } 899436c6c9cSStella Laurenzo 900436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() { 901436c6c9cSStella Laurenzo if (!location) 902436c6c9cSStella Laurenzo return nullptr; 903b56d1ec6SPeter Hawkins return nb::cast<PyLocation *>(location); 904436c6c9cSStella Laurenzo } 905436c6c9cSStella Laurenzo 906436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() { 907436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 908436c6c9cSStella Laurenzo return tos ? tos->getContext() : nullptr; 909436c6c9cSStella Laurenzo } 910436c6c9cSStella Laurenzo 911436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 912436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 913436c6c9cSStella Laurenzo return tos ? tos->getInsertionPoint() : nullptr; 914436c6c9cSStella Laurenzo } 915436c6c9cSStella Laurenzo 916436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() { 917436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 918436c6c9cSStella Laurenzo return tos ? tos->getLocation() : nullptr; 919436c6c9cSStella Laurenzo } 920436c6c9cSStella Laurenzo 921b56d1ec6SPeter Hawkins nb::object PyThreadContextEntry::pushContext(nb::object context) { 922b56d1ec6SPeter Hawkins push(FrameKind::Context, /*context=*/context, 923b56d1ec6SPeter Hawkins /*insertionPoint=*/nb::object(), 924b56d1ec6SPeter Hawkins /*location=*/nb::object()); 925b56d1ec6SPeter Hawkins return context; 926436c6c9cSStella Laurenzo } 927436c6c9cSStella Laurenzo 928436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) { 929436c6c9cSStella Laurenzo auto &stack = getStack(); 930436c6c9cSStella Laurenzo if (stack.empty()) 9314811270bSmax throw std::runtime_error("Unbalanced Context enter/exit"); 932436c6c9cSStella Laurenzo auto &tos = stack.back(); 933436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 9344811270bSmax throw std::runtime_error("Unbalanced Context enter/exit"); 935436c6c9cSStella Laurenzo stack.pop_back(); 936436c6c9cSStella Laurenzo } 937436c6c9cSStella Laurenzo 938b56d1ec6SPeter Hawkins nb::object 939b56d1ec6SPeter Hawkins PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { 940b56d1ec6SPeter Hawkins PyInsertionPoint &insertionPoint = 941b56d1ec6SPeter Hawkins nb::cast<PyInsertionPoint &>(insertionPointObj); 942b56d1ec6SPeter Hawkins nb::object contextObj = 943436c6c9cSStella Laurenzo insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 944436c6c9cSStella Laurenzo push(FrameKind::InsertionPoint, 945436c6c9cSStella Laurenzo /*context=*/contextObj, 946436c6c9cSStella Laurenzo /*insertionPoint=*/insertionPointObj, 947b56d1ec6SPeter Hawkins /*location=*/nb::object()); 948436c6c9cSStella Laurenzo return insertionPointObj; 949436c6c9cSStella Laurenzo } 950436c6c9cSStella Laurenzo 951436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 952436c6c9cSStella Laurenzo auto &stack = getStack(); 953436c6c9cSStella Laurenzo if (stack.empty()) 9544811270bSmax throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); 955436c6c9cSStella Laurenzo auto &tos = stack.back(); 956436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::InsertionPoint && 957436c6c9cSStella Laurenzo tos.getInsertionPoint() != &insertionPoint) 9584811270bSmax throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); 959436c6c9cSStella Laurenzo stack.pop_back(); 960436c6c9cSStella Laurenzo } 961436c6c9cSStella Laurenzo 962b56d1ec6SPeter Hawkins nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { 963b56d1ec6SPeter Hawkins PyLocation &location = nb::cast<PyLocation &>(locationObj); 964b56d1ec6SPeter Hawkins nb::object contextObj = location.getContext().getObject(); 965436c6c9cSStella Laurenzo push(FrameKind::Location, /*context=*/contextObj, 966b56d1ec6SPeter Hawkins /*insertionPoint=*/nb::object(), 967436c6c9cSStella Laurenzo /*location=*/locationObj); 968436c6c9cSStella Laurenzo return locationObj; 969436c6c9cSStella Laurenzo } 970436c6c9cSStella Laurenzo 971436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) { 972436c6c9cSStella Laurenzo auto &stack = getStack(); 973436c6c9cSStella Laurenzo if (stack.empty()) 9744811270bSmax throw std::runtime_error("Unbalanced Location enter/exit"); 975436c6c9cSStella Laurenzo auto &tos = stack.back(); 976436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 9774811270bSmax throw std::runtime_error("Unbalanced Location enter/exit"); 978436c6c9cSStella Laurenzo stack.pop_back(); 979436c6c9cSStella Laurenzo } 980436c6c9cSStella Laurenzo 981436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 9827ee25bc5SStella Laurenzo // PyDiagnostic* 9837ee25bc5SStella Laurenzo //------------------------------------------------------------------------------ 9847ee25bc5SStella Laurenzo 9857ee25bc5SStella Laurenzo void PyDiagnostic::invalidate() { 9867ee25bc5SStella Laurenzo valid = false; 9877ee25bc5SStella Laurenzo if (materializedNotes) { 988b56d1ec6SPeter Hawkins for (nb::handle noteObject : *materializedNotes) { 989b56d1ec6SPeter Hawkins PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject); 9907ee25bc5SStella Laurenzo note->invalidate(); 9917ee25bc5SStella Laurenzo } 9927ee25bc5SStella Laurenzo } 9937ee25bc5SStella Laurenzo } 9947ee25bc5SStella Laurenzo 9957ee25bc5SStella Laurenzo PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 996b56d1ec6SPeter Hawkins nb::object callback) 9977ee25bc5SStella Laurenzo : context(context), callback(std::move(callback)) {} 9987ee25bc5SStella Laurenzo 9996a38cbfbSMehdi Amini PyDiagnosticHandler::~PyDiagnosticHandler() = default; 10007ee25bc5SStella Laurenzo 10017ee25bc5SStella Laurenzo void PyDiagnosticHandler::detach() { 10027ee25bc5SStella Laurenzo if (!registeredID) 10037ee25bc5SStella Laurenzo return; 10047ee25bc5SStella Laurenzo MlirDiagnosticHandlerID localID = *registeredID; 10057ee25bc5SStella Laurenzo mlirContextDetachDiagnosticHandler(context, localID); 10067ee25bc5SStella Laurenzo assert(!registeredID && "should have unregistered"); 10077ee25bc5SStella Laurenzo // Not strictly necessary but keeps stale pointers from being around to cause 10087ee25bc5SStella Laurenzo // issues. 10097ee25bc5SStella Laurenzo context = {nullptr}; 10107ee25bc5SStella Laurenzo } 10117ee25bc5SStella Laurenzo 10127ee25bc5SStella Laurenzo void PyDiagnostic::checkValid() { 10137ee25bc5SStella Laurenzo if (!valid) { 10147ee25bc5SStella Laurenzo throw std::invalid_argument( 10157ee25bc5SStella Laurenzo "Diagnostic is invalid (used outside of callback)"); 10167ee25bc5SStella Laurenzo } 10177ee25bc5SStella Laurenzo } 10187ee25bc5SStella Laurenzo 10197ee25bc5SStella Laurenzo MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 10207ee25bc5SStella Laurenzo checkValid(); 10217ee25bc5SStella Laurenzo return mlirDiagnosticGetSeverity(diagnostic); 10227ee25bc5SStella Laurenzo } 10237ee25bc5SStella Laurenzo 10247ee25bc5SStella Laurenzo PyLocation PyDiagnostic::getLocation() { 10257ee25bc5SStella Laurenzo checkValid(); 10267ee25bc5SStella Laurenzo MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 10277ee25bc5SStella Laurenzo MlirContext context = mlirLocationGetContext(loc); 10287ee25bc5SStella Laurenzo return PyLocation(PyMlirContext::forContext(context), loc); 10297ee25bc5SStella Laurenzo } 10307ee25bc5SStella Laurenzo 1031b56d1ec6SPeter Hawkins nb::str PyDiagnostic::getMessage() { 10327ee25bc5SStella Laurenzo checkValid(); 1033b56d1ec6SPeter Hawkins nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); 10347ee25bc5SStella Laurenzo PyFileAccumulator accum(fileObject, /*binary=*/false); 10357ee25bc5SStella Laurenzo mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 1036b56d1ec6SPeter Hawkins return nb::cast<nb::str>(fileObject.attr("getvalue")()); 10377ee25bc5SStella Laurenzo } 10387ee25bc5SStella Laurenzo 1039b56d1ec6SPeter Hawkins nb::tuple PyDiagnostic::getNotes() { 10407ee25bc5SStella Laurenzo checkValid(); 10417ee25bc5SStella Laurenzo if (materializedNotes) 10427ee25bc5SStella Laurenzo return *materializedNotes; 10437ee25bc5SStella Laurenzo intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 1044b56d1ec6SPeter Hawkins nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes)); 10457ee25bc5SStella Laurenzo for (intptr_t i = 0; i < numNotes; ++i) { 10467ee25bc5SStella Laurenzo MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 1047b56d1ec6SPeter Hawkins nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); 1048b56d1ec6SPeter Hawkins PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); 10497ee25bc5SStella Laurenzo } 1050b56d1ec6SPeter Hawkins materializedNotes = std::move(notes); 1051b56d1ec6SPeter Hawkins 10527ee25bc5SStella Laurenzo return *materializedNotes; 10537ee25bc5SStella Laurenzo } 10547ee25bc5SStella Laurenzo 10553ea4c501SRahul Kayaith PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { 10563ea4c501SRahul Kayaith std::vector<DiagnosticInfo> notes; 1057b56d1ec6SPeter Hawkins for (nb::handle n : getNotes()) 1058b56d1ec6SPeter Hawkins notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo()); 1059b56d1ec6SPeter Hawkins return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()), 1060b56d1ec6SPeter Hawkins std::move(notes)}; 10613ea4c501SRahul Kayaith } 10623ea4c501SRahul Kayaith 10637ee25bc5SStella Laurenzo //------------------------------------------------------------------------------ 10645e83a5b4SStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry 1065436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1066436c6c9cSStella Laurenzo 1067436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key, 1068436c6c9cSStella Laurenzo bool attrError) { 1069f8479d9dSRiver Riddle MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 1070f8479d9dSRiver Riddle {key.data(), key.size()}); 1071436c6c9cSStella Laurenzo if (mlirDialectIsNull(dialect)) { 10724811270bSmax std::string msg = (Twine("Dialect '") + key + "' not found").str(); 10734811270bSmax if (attrError) 1074b56d1ec6SPeter Hawkins throw nb::attribute_error(msg.c_str()); 1075b56d1ec6SPeter Hawkins throw nb::index_error(msg.c_str()); 1076436c6c9cSStella Laurenzo } 1077436c6c9cSStella Laurenzo return dialect; 1078436c6c9cSStella Laurenzo } 1079436c6c9cSStella Laurenzo 1080b56d1ec6SPeter Hawkins nb::object PyDialectRegistry::getCapsule() { 1081b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this)); 10825e83a5b4SStella Laurenzo } 10835e83a5b4SStella Laurenzo 1084b56d1ec6SPeter Hawkins PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { 10855e83a5b4SStella Laurenzo MlirDialectRegistry rawRegistry = 10865e83a5b4SStella Laurenzo mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 10875e83a5b4SStella Laurenzo if (mlirDialectRegistryIsNull(rawRegistry)) 1088b56d1ec6SPeter Hawkins throw nb::python_error(); 10895e83a5b4SStella Laurenzo return PyDialectRegistry(rawRegistry); 10905e83a5b4SStella Laurenzo } 10915e83a5b4SStella Laurenzo 1092436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1093436c6c9cSStella Laurenzo // PyLocation 1094436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1095436c6c9cSStella Laurenzo 1096b56d1ec6SPeter Hawkins nb::object PyLocation::getCapsule() { 1097b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this)); 1098436c6c9cSStella Laurenzo } 1099436c6c9cSStella Laurenzo 1100b56d1ec6SPeter Hawkins PyLocation PyLocation::createFromCapsule(nb::object capsule) { 1101436c6c9cSStella Laurenzo MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 1102436c6c9cSStella Laurenzo if (mlirLocationIsNull(rawLoc)) 1103b56d1ec6SPeter Hawkins throw nb::python_error(); 1104436c6c9cSStella Laurenzo return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 1105436c6c9cSStella Laurenzo rawLoc); 1106436c6c9cSStella Laurenzo } 1107436c6c9cSStella Laurenzo 1108b56d1ec6SPeter Hawkins nb::object PyLocation::contextEnter(nb::object locationObj) { 1109b56d1ec6SPeter Hawkins return PyThreadContextEntry::pushLocation(locationObj); 1110436c6c9cSStella Laurenzo } 1111436c6c9cSStella Laurenzo 1112b56d1ec6SPeter Hawkins void PyLocation::contextExit(const nb::object &excType, 1113b56d1ec6SPeter Hawkins const nb::object &excVal, 1114b56d1ec6SPeter Hawkins const nb::object &excTb) { 1115436c6c9cSStella Laurenzo PyThreadContextEntry::popLocation(*this); 1116436c6c9cSStella Laurenzo } 1117436c6c9cSStella Laurenzo 1118436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() { 1119436c6c9cSStella Laurenzo auto *location = PyThreadContextEntry::getDefaultLocation(); 1120436c6c9cSStella Laurenzo if (!location) { 11214811270bSmax throw std::runtime_error( 1122436c6c9cSStella Laurenzo "An MLIR function requires a Location but none was provided in the " 1123436c6c9cSStella Laurenzo "call or from the surrounding environment. Either pass to the function " 1124436c6c9cSStella Laurenzo "with a 'loc=' argument or establish a default using 'with loc:'"); 1125436c6c9cSStella Laurenzo } 1126436c6c9cSStella Laurenzo return *location; 1127436c6c9cSStella Laurenzo } 1128436c6c9cSStella Laurenzo 1129436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1130436c6c9cSStella Laurenzo // PyModule 1131436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1132436c6c9cSStella Laurenzo 1133436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 1134436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), module(module) {} 1135436c6c9cSStella Laurenzo 1136436c6c9cSStella Laurenzo PyModule::~PyModule() { 1137b56d1ec6SPeter Hawkins nb::gil_scoped_acquire acquire; 1138436c6c9cSStella Laurenzo auto &liveModules = getContext()->liveModules; 1139436c6c9cSStella Laurenzo assert(liveModules.count(module.ptr) == 1 && 1140436c6c9cSStella Laurenzo "destroying module not in live map"); 1141436c6c9cSStella Laurenzo liveModules.erase(module.ptr); 1142436c6c9cSStella Laurenzo mlirModuleDestroy(module); 1143436c6c9cSStella Laurenzo } 1144436c6c9cSStella Laurenzo 1145436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) { 1146436c6c9cSStella Laurenzo MlirContext context = mlirModuleGetContext(module); 1147436c6c9cSStella Laurenzo PyMlirContextRef contextRef = PyMlirContext::forContext(context); 1148436c6c9cSStella Laurenzo 1149b56d1ec6SPeter Hawkins nb::gil_scoped_acquire acquire; 1150436c6c9cSStella Laurenzo auto &liveModules = contextRef->liveModules; 1151436c6c9cSStella Laurenzo auto it = liveModules.find(module.ptr); 1152436c6c9cSStella Laurenzo if (it == liveModules.end()) { 1153436c6c9cSStella Laurenzo // Create. 1154436c6c9cSStella Laurenzo PyModule *unownedModule = new PyModule(std::move(contextRef), module); 1155436c6c9cSStella Laurenzo // Note that the default return value policy on cast is automatic_reference, 1156436c6c9cSStella Laurenzo // which does not take ownership (delete will not be called). 1157436c6c9cSStella Laurenzo // Just be explicit. 1158b56d1ec6SPeter Hawkins nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); 1159436c6c9cSStella Laurenzo unownedModule->handle = pyRef; 1160436c6c9cSStella Laurenzo liveModules[module.ptr] = 1161436c6c9cSStella Laurenzo std::make_pair(unownedModule->handle, unownedModule); 1162436c6c9cSStella Laurenzo return PyModuleRef(unownedModule, std::move(pyRef)); 1163436c6c9cSStella Laurenzo } 1164436c6c9cSStella Laurenzo // Use existing. 1165436c6c9cSStella Laurenzo PyModule *existing = it->second.second; 1166b56d1ec6SPeter Hawkins nb::object pyRef = nb::borrow<nb::object>(it->second.first); 1167436c6c9cSStella Laurenzo return PyModuleRef(existing, std::move(pyRef)); 1168436c6c9cSStella Laurenzo } 1169436c6c9cSStella Laurenzo 1170b56d1ec6SPeter Hawkins nb::object PyModule::createFromCapsule(nb::object capsule) { 1171436c6c9cSStella Laurenzo MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 1172436c6c9cSStella Laurenzo if (mlirModuleIsNull(rawModule)) 1173b56d1ec6SPeter Hawkins throw nb::python_error(); 1174436c6c9cSStella Laurenzo return forModule(rawModule).releaseObject(); 1175436c6c9cSStella Laurenzo } 1176436c6c9cSStella Laurenzo 1177b56d1ec6SPeter Hawkins nb::object PyModule::getCapsule() { 1178b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonModuleToCapsule(get())); 1179436c6c9cSStella Laurenzo } 1180436c6c9cSStella Laurenzo 1181436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1182436c6c9cSStella Laurenzo // PyOperation 1183436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1184436c6c9cSStella Laurenzo 1185436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 1186436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), operation(operation) {} 1187436c6c9cSStella Laurenzo 1188436c6c9cSStella Laurenzo PyOperation::~PyOperation() { 118949745f87SMike Urbach // If the operation has already been invalidated there is nothing to do. 119049745f87SMike Urbach if (!valid) 119149745f87SMike Urbach return; 119267897d77SOleksandr "Alex" Zinenko 119367897d77SOleksandr "Alex" Zinenko // Otherwise, invalidate the operation and remove it from live map when it is 119467897d77SOleksandr "Alex" Zinenko // attached. 119567897d77SOleksandr "Alex" Zinenko if (isAttached()) { 119667897d77SOleksandr "Alex" Zinenko getContext()->clearOperation(*this); 119767897d77SOleksandr "Alex" Zinenko } else { 119867897d77SOleksandr "Alex" Zinenko // And destroy it when it is detached, i.e. owned by Python, in which case 119967897d77SOleksandr "Alex" Zinenko // all nested operations must be invalidated at removed from the live map as 120067897d77SOleksandr "Alex" Zinenko // well. 120167897d77SOleksandr "Alex" Zinenko erase(); 1202436c6c9cSStella Laurenzo } 1203436c6c9cSStella Laurenzo } 1204436c6c9cSStella Laurenzo 1205e30b7030SPeter Hawkins namespace { 1206e30b7030SPeter Hawkins 1207e30b7030SPeter Hawkins // Constructs a new object of type T in-place on the Python heap, returning a 1208e30b7030SPeter Hawkins // PyObjectRef to it, loosely analogous to std::make_shared<T>(). 1209e30b7030SPeter Hawkins template <typename T, class... Args> 1210e30b7030SPeter Hawkins PyObjectRef<T> makeObjectRef(Args &&...args) { 1211e30b7030SPeter Hawkins nb::handle type = nb::type<T>(); 1212e30b7030SPeter Hawkins nb::object instance = nb::inst_alloc(type); 1213e30b7030SPeter Hawkins T *ptr = nb::inst_ptr<T>(instance); 1214e30b7030SPeter Hawkins new (ptr) T(std::forward<Args>(args)...); 1215e30b7030SPeter Hawkins nb::inst_mark_ready(instance); 1216e30b7030SPeter Hawkins return PyObjectRef<T>(ptr, std::move(instance)); 1217e30b7030SPeter Hawkins } 1218e30b7030SPeter Hawkins 1219e30b7030SPeter Hawkins } // namespace 1220e30b7030SPeter Hawkins 1221436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 1222436c6c9cSStella Laurenzo MlirOperation operation, 1223b56d1ec6SPeter Hawkins nb::object parentKeepAlive) { 1224436c6c9cSStella Laurenzo // Create. 1225e30b7030SPeter Hawkins PyOperationRef unownedOperation = 1226e30b7030SPeter Hawkins makeObjectRef<PyOperation>(std::move(contextRef), operation); 1227e30b7030SPeter Hawkins unownedOperation->handle = unownedOperation.getObject(); 1228436c6c9cSStella Laurenzo if (parentKeepAlive) { 1229436c6c9cSStella Laurenzo unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 1230436c6c9cSStella Laurenzo } 1231e30b7030SPeter Hawkins return unownedOperation; 1232436c6c9cSStella Laurenzo } 1233436c6c9cSStella Laurenzo 1234436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 1235436c6c9cSStella Laurenzo MlirOperation operation, 1236b56d1ec6SPeter Hawkins nb::object parentKeepAlive) { 1237e2c49a45SPeter Hawkins nb::ft_lock_guard lock(contextRef->liveOperationsMutex); 1238436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 1239436c6c9cSStella Laurenzo auto it = liveOperations.find(operation.ptr); 1240436c6c9cSStella Laurenzo if (it == liveOperations.end()) { 1241436c6c9cSStella Laurenzo // Create. 1242e2c49a45SPeter Hawkins PyOperationRef result = createInstance(std::move(contextRef), operation, 1243436c6c9cSStella Laurenzo std::move(parentKeepAlive)); 1244e2c49a45SPeter Hawkins liveOperations[operation.ptr] = 1245e2c49a45SPeter Hawkins std::make_pair(result.getObject(), result.get()); 1246e2c49a45SPeter Hawkins return result; 1247436c6c9cSStella Laurenzo } 1248436c6c9cSStella Laurenzo // Use existing. 1249436c6c9cSStella Laurenzo PyOperation *existing = it->second.second; 1250b56d1ec6SPeter Hawkins nb::object pyRef = nb::borrow<nb::object>(it->second.first); 1251436c6c9cSStella Laurenzo return PyOperationRef(existing, std::move(pyRef)); 1252436c6c9cSStella Laurenzo } 1253436c6c9cSStella Laurenzo 1254436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 1255436c6c9cSStella Laurenzo MlirOperation operation, 1256b56d1ec6SPeter Hawkins nb::object parentKeepAlive) { 1257e2c49a45SPeter Hawkins nb::ft_lock_guard lock(contextRef->liveOperationsMutex); 1258436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 1259436c6c9cSStella Laurenzo assert(liveOperations.count(operation.ptr) == 0 && 1260436c6c9cSStella Laurenzo "cannot create detached operation that already exists"); 1261436c6c9cSStella Laurenzo (void)liveOperations; 1262436c6c9cSStella Laurenzo PyOperationRef created = createInstance(std::move(contextRef), operation, 1263436c6c9cSStella Laurenzo std::move(parentKeepAlive)); 1264e2c49a45SPeter Hawkins liveOperations[operation.ptr] = 1265e2c49a45SPeter Hawkins std::make_pair(created.getObject(), created.get()); 1266436c6c9cSStella Laurenzo created->attached = false; 1267436c6c9cSStella Laurenzo return created; 1268436c6c9cSStella Laurenzo } 1269436c6c9cSStella Laurenzo 127037107e17Srkayaith PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, 127137107e17Srkayaith const std::string &sourceStr, 127237107e17Srkayaith const std::string &sourceName) { 12733ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(contextRef); 127437107e17Srkayaith MlirOperation op = 127537107e17Srkayaith mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), 127637107e17Srkayaith toMlirStringRef(sourceName)); 127737107e17Srkayaith if (mlirOperationIsNull(op)) 12783ea4c501SRahul Kayaith throw MLIRError("Unable to parse operation assembly", errors.take()); 127937107e17Srkayaith return PyOperation::createDetached(std::move(contextRef), op); 128037107e17Srkayaith } 128137107e17Srkayaith 1282436c6c9cSStella Laurenzo void PyOperation::checkValid() const { 1283436c6c9cSStella Laurenzo if (!valid) { 12844811270bSmax throw std::runtime_error("the operation has been invalidated"); 1285436c6c9cSStella Laurenzo } 1286436c6c9cSStella Laurenzo } 1287436c6c9cSStella Laurenzo 1288204acc5cSJacques Pienaar void PyOperationBase::print(std::optional<int64_t> largeElementsLimit, 1289436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 1290ace1d0adSStella Laurenzo bool printGenericOpForm, bool useLocalScope, 1291b56d1ec6SPeter Hawkins bool assumeVerified, nb::object fileObject, 1292abad8455SJonas Rickert bool binary, bool skipRegions) { 1293436c6c9cSStella Laurenzo PyOperation &operation = getOperation(); 1294436c6c9cSStella Laurenzo operation.checkValid(); 1295436c6c9cSStella Laurenzo if (fileObject.is_none()) 1296b56d1ec6SPeter Hawkins fileObject = nb::module_::import_("sys").attr("stdout"); 1297436c6c9cSStella Laurenzo 1298436c6c9cSStella Laurenzo MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 1299436c6c9cSStella Laurenzo if (largeElementsLimit) 1300436c6c9cSStella Laurenzo mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1301436c6c9cSStella Laurenzo if (enableDebugInfo) 1302d0236611SRiver Riddle mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, 1303d0236611SRiver Riddle /*prettyForm=*/prettyDebugInfo); 1304436c6c9cSStella Laurenzo if (printGenericOpForm) 1305436c6c9cSStella Laurenzo mlirOpPrintingFlagsPrintGenericOpForm(flags); 1306bccf27d9SMark Browning if (useLocalScope) 1307bccf27d9SMark Browning mlirOpPrintingFlagsUseLocalScope(flags); 13082aa12583SRahul Kayaith if (assumeVerified) 13092aa12583SRahul Kayaith mlirOpPrintingFlagsAssumeVerified(flags); 1310abad8455SJonas Rickert if (skipRegions) 1311abad8455SJonas Rickert mlirOpPrintingFlagsSkipRegions(flags); 1312436c6c9cSStella Laurenzo 1313436c6c9cSStella Laurenzo PyFileAccumulator accum(fileObject, binary); 1314436c6c9cSStella Laurenzo mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1315436c6c9cSStella Laurenzo accum.getUserData()); 1316436c6c9cSStella Laurenzo mlirOpPrintingFlagsDestroy(flags); 1317436c6c9cSStella Laurenzo } 1318436c6c9cSStella Laurenzo 1319b56d1ec6SPeter Hawkins void PyOperationBase::print(PyAsmState &state, nb::object fileObject, 1320204acc5cSJacques Pienaar bool binary) { 1321204acc5cSJacques Pienaar PyOperation &operation = getOperation(); 1322204acc5cSJacques Pienaar operation.checkValid(); 1323204acc5cSJacques Pienaar if (fileObject.is_none()) 1324b56d1ec6SPeter Hawkins fileObject = nb::module_::import_("sys").attr("stdout"); 1325204acc5cSJacques Pienaar PyFileAccumulator accum(fileObject, binary); 1326204acc5cSJacques Pienaar mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), 1327204acc5cSJacques Pienaar accum.getUserData()); 1328204acc5cSJacques Pienaar } 1329204acc5cSJacques Pienaar 1330b56d1ec6SPeter Hawkins void PyOperationBase::writeBytecode(const nb::object &fileObject, 13310610e2f6SJacques Pienaar std::optional<int64_t> bytecodeVersion) { 133289418ddcSMehdi Amini PyOperation &operation = getOperation(); 133389418ddcSMehdi Amini operation.checkValid(); 133489418ddcSMehdi Amini PyFileAccumulator accum(fileObject, /*binary=*/true); 13350610e2f6SJacques Pienaar 13360610e2f6SJacques Pienaar if (!bytecodeVersion.has_value()) 13370610e2f6SJacques Pienaar return mlirOperationWriteBytecode(operation, accum.getCallback(), 133889418ddcSMehdi Amini accum.getUserData()); 13390610e2f6SJacques Pienaar 13400610e2f6SJacques Pienaar MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); 13410610e2f6SJacques Pienaar mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); 13425c90e1ffSJacques Pienaar MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( 13430610e2f6SJacques Pienaar operation, config, accum.getCallback(), accum.getUserData()); 13449816cc91SAdam Paszke mlirBytecodeWriterConfigDestroy(config); 13455c90e1ffSJacques Pienaar if (mlirLogicalResultIsFailure(res)) 1346b56d1ec6SPeter Hawkins throw nb::value_error((Twine("Unable to honor desired bytecode version ") + 13475c90e1ffSJacques Pienaar Twine(*bytecodeVersion)) 1348b56d1ec6SPeter Hawkins .str() 1349b56d1ec6SPeter Hawkins .c_str()); 135089418ddcSMehdi Amini } 135189418ddcSMehdi Amini 135247148832SHideto Ueno void PyOperationBase::walk( 135347148832SHideto Ueno std::function<MlirWalkResult(MlirOperation)> callback, 135447148832SHideto Ueno MlirWalkOrder walkOrder) { 135547148832SHideto Ueno PyOperation &operation = getOperation(); 135647148832SHideto Ueno operation.checkValid(); 1357bc553646Stomnatan30 struct UserData { 1358bc553646Stomnatan30 std::function<MlirWalkResult(MlirOperation)> callback; 1359bc553646Stomnatan30 bool gotException; 1360bc553646Stomnatan30 std::string exceptionWhat; 1361b56d1ec6SPeter Hawkins nb::object exceptionType; 1362bc553646Stomnatan30 }; 1363bc553646Stomnatan30 UserData userData{callback, false, {}, {}}; 136447148832SHideto Ueno MlirOperationWalkCallback walkCallback = [](MlirOperation op, 136547148832SHideto Ueno void *userData) { 1366bc553646Stomnatan30 UserData *calleeUserData = static_cast<UserData *>(userData); 1367bc553646Stomnatan30 try { 1368bc553646Stomnatan30 return (calleeUserData->callback)(op); 1369b56d1ec6SPeter Hawkins } catch (nb::python_error &e) { 1370bc553646Stomnatan30 calleeUserData->gotException = true; 1371b56d1ec6SPeter Hawkins calleeUserData->exceptionWhat = std::string(e.what()); 1372b56d1ec6SPeter Hawkins calleeUserData->exceptionType = nb::borrow(e.type()); 1373bc553646Stomnatan30 return MlirWalkResult::MlirWalkResultInterrupt; 1374bc553646Stomnatan30 } 137547148832SHideto Ueno }; 1376bc553646Stomnatan30 mlirOperationWalk(operation, walkCallback, &userData, walkOrder); 1377bc553646Stomnatan30 if (userData.gotException) { 1378bc553646Stomnatan30 std::string message("Exception raised in callback: "); 1379bc553646Stomnatan30 message.append(userData.exceptionWhat); 1380bc553646Stomnatan30 throw std::runtime_error(message); 1381bc553646Stomnatan30 } 138247148832SHideto Ueno } 138347148832SHideto Ueno 1384b56d1ec6SPeter Hawkins nb::object PyOperationBase::getAsm(bool binary, 13850a81ace0SKazu Hirata std::optional<int64_t> largeElementsLimit, 1386436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 1387ace1d0adSStella Laurenzo bool printGenericOpForm, bool useLocalScope, 1388abad8455SJonas Rickert bool assumeVerified, bool skipRegions) { 1389b56d1ec6SPeter Hawkins nb::object fileObject; 1390436c6c9cSStella Laurenzo if (binary) { 1391b56d1ec6SPeter Hawkins fileObject = nb::module_::import_("io").attr("BytesIO")(); 1392436c6c9cSStella Laurenzo } else { 1393b56d1ec6SPeter Hawkins fileObject = nb::module_::import_("io").attr("StringIO")(); 1394436c6c9cSStella Laurenzo } 1395204acc5cSJacques Pienaar print(/*largeElementsLimit=*/largeElementsLimit, 1396436c6c9cSStella Laurenzo /*enableDebugInfo=*/enableDebugInfo, 1397436c6c9cSStella Laurenzo /*prettyDebugInfo=*/prettyDebugInfo, 1398436c6c9cSStella Laurenzo /*printGenericOpForm=*/printGenericOpForm, 1399ace1d0adSStella Laurenzo /*useLocalScope=*/useLocalScope, 1400204acc5cSJacques Pienaar /*assumeVerified=*/assumeVerified, 1401204acc5cSJacques Pienaar /*fileObject=*/fileObject, 1402abad8455SJonas Rickert /*binary=*/binary, 1403abad8455SJonas Rickert /*skipRegions=*/skipRegions); 1404436c6c9cSStella Laurenzo 1405436c6c9cSStella Laurenzo return fileObject.attr("getvalue")(); 1406436c6c9cSStella Laurenzo } 1407436c6c9cSStella Laurenzo 140824685aaeSAlex Zinenko void PyOperationBase::moveAfter(PyOperationBase &other) { 140924685aaeSAlex Zinenko PyOperation &operation = getOperation(); 141024685aaeSAlex Zinenko PyOperation &otherOp = other.getOperation(); 141124685aaeSAlex Zinenko operation.checkValid(); 141224685aaeSAlex Zinenko otherOp.checkValid(); 141324685aaeSAlex Zinenko mlirOperationMoveAfter(operation, otherOp); 141424685aaeSAlex Zinenko operation.parentKeepAlive = otherOp.parentKeepAlive; 141524685aaeSAlex Zinenko } 141624685aaeSAlex Zinenko 141724685aaeSAlex Zinenko void PyOperationBase::moveBefore(PyOperationBase &other) { 141824685aaeSAlex Zinenko PyOperation &operation = getOperation(); 141924685aaeSAlex Zinenko PyOperation &otherOp = other.getOperation(); 142024685aaeSAlex Zinenko operation.checkValid(); 142124685aaeSAlex Zinenko otherOp.checkValid(); 142224685aaeSAlex Zinenko mlirOperationMoveBefore(operation, otherOp); 142324685aaeSAlex Zinenko operation.parentKeepAlive = otherOp.parentKeepAlive; 142424685aaeSAlex Zinenko } 142524685aaeSAlex Zinenko 14263ea4c501SRahul Kayaith bool PyOperationBase::verify() { 14273ea4c501SRahul Kayaith PyOperation &op = getOperation(); 14283ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(op.getContext()); 14293ea4c501SRahul Kayaith if (!mlirOperationVerify(op.get())) 14303ea4c501SRahul Kayaith throw MLIRError("Verification failed", errors.take()); 14313ea4c501SRahul Kayaith return true; 14323ea4c501SRahul Kayaith } 14333ea4c501SRahul Kayaith 14340a81ace0SKazu Hirata std::optional<PyOperationRef> PyOperation::getParentOperation() { 143549745f87SMike Urbach checkValid(); 1436436c6c9cSStella Laurenzo if (!isAttached()) 1437b56d1ec6SPeter Hawkins throw nb::value_error("Detached operations have no parent"); 1438436c6c9cSStella Laurenzo MlirOperation operation = mlirOperationGetParentOperation(get()); 1439436c6c9cSStella Laurenzo if (mlirOperationIsNull(operation)) 14401689dadeSJohn Demme return {}; 1441436c6c9cSStella Laurenzo return PyOperation::forOperation(getContext(), operation); 1442436c6c9cSStella Laurenzo } 1443436c6c9cSStella Laurenzo 1444436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() { 144549745f87SMike Urbach checkValid(); 14460a81ace0SKazu Hirata std::optional<PyOperationRef> parentOperation = getParentOperation(); 1447436c6c9cSStella Laurenzo MlirBlock block = mlirOperationGetBlock(get()); 1448436c6c9cSStella Laurenzo assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 14491689dadeSJohn Demme assert(parentOperation && "Operation has no parent"); 14501689dadeSJohn Demme return PyBlock{std::move(*parentOperation), block}; 1451436c6c9cSStella Laurenzo } 1452436c6c9cSStella Laurenzo 1453b56d1ec6SPeter Hawkins nb::object PyOperation::getCapsule() { 145449745f87SMike Urbach checkValid(); 1455b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonOperationToCapsule(get())); 14560126e906SJohn Demme } 14570126e906SJohn Demme 1458b56d1ec6SPeter Hawkins nb::object PyOperation::createFromCapsule(nb::object capsule) { 14590126e906SJohn Demme MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 14600126e906SJohn Demme if (mlirOperationIsNull(rawOperation)) 1461b56d1ec6SPeter Hawkins throw nb::python_error(); 14620126e906SJohn Demme MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 146378bd1246SAlex Zinenko return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 14640126e906SJohn Demme .releaseObject(); 14650126e906SJohn Demme } 14660126e906SJohn Demme 1467774818c0SDominik Grewe static void maybeInsertOperation(PyOperationRef &op, 1468b56d1ec6SPeter Hawkins const nb::object &maybeIp) { 1469774818c0SDominik Grewe // InsertPoint active? 1470b56d1ec6SPeter Hawkins if (!maybeIp.is(nb::cast(false))) { 1471774818c0SDominik Grewe PyInsertionPoint *ip; 1472774818c0SDominik Grewe if (maybeIp.is_none()) { 1473774818c0SDominik Grewe ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1474774818c0SDominik Grewe } else { 1475b56d1ec6SPeter Hawkins ip = nb::cast<PyInsertionPoint *>(maybeIp); 1476774818c0SDominik Grewe } 1477774818c0SDominik Grewe if (ip) 1478774818c0SDominik Grewe ip->insert(*op.get()); 1479774818c0SDominik Grewe } 1480774818c0SDominik Grewe } 1481774818c0SDominik Grewe 1482f4125e02SPeter Hawkins nb::object PyOperation::create(std::string_view name, 14830a81ace0SKazu Hirata std::optional<std::vector<PyType *>> results, 1484*acde3f72SPeter Hawkins llvm::ArrayRef<MlirValue> operands, 1485b56d1ec6SPeter Hawkins std::optional<nb::dict> attributes, 14860a81ace0SKazu Hirata std::optional<std::vector<PyBlock *>> successors, 14870a81ace0SKazu Hirata int regions, DefaultingPyLocation location, 1488b56d1ec6SPeter Hawkins const nb::object &maybeIp, bool inferType) { 1489436c6c9cSStella Laurenzo llvm::SmallVector<MlirType, 4> mlirResults; 1490436c6c9cSStella Laurenzo llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1491436c6c9cSStella Laurenzo llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1492436c6c9cSStella Laurenzo 1493436c6c9cSStella Laurenzo // General parameter validation. 1494436c6c9cSStella Laurenzo if (regions < 0) 1495b56d1ec6SPeter Hawkins throw nb::value_error("number of regions must be >= 0"); 1496436c6c9cSStella Laurenzo 1497436c6c9cSStella Laurenzo // Unpack/validate results. 1498436c6c9cSStella Laurenzo if (results) { 1499436c6c9cSStella Laurenzo mlirResults.reserve(results->size()); 1500436c6c9cSStella Laurenzo for (PyType *result : *results) { 1501436c6c9cSStella Laurenzo // TODO: Verify result type originate from the same context. 1502436c6c9cSStella Laurenzo if (!result) 1503b56d1ec6SPeter Hawkins throw nb::value_error("result type cannot be None"); 1504436c6c9cSStella Laurenzo mlirResults.push_back(*result); 1505436c6c9cSStella Laurenzo } 1506436c6c9cSStella Laurenzo } 1507436c6c9cSStella Laurenzo // Unpack/validate attributes. 1508436c6c9cSStella Laurenzo if (attributes) { 1509436c6c9cSStella Laurenzo mlirAttributes.reserve(attributes->size()); 1510b56d1ec6SPeter Hawkins for (std::pair<nb::handle, nb::handle> it : *attributes) { 1511436c6c9cSStella Laurenzo std::string key; 1512436c6c9cSStella Laurenzo try { 1513b56d1ec6SPeter Hawkins key = nb::cast<std::string>(it.first); 1514b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 1515436c6c9cSStella Laurenzo std::string msg = "Invalid attribute key (not a string) when " 1516436c6c9cSStella Laurenzo "attempting to create the operation \"" + 1517f4125e02SPeter Hawkins std::string(name) + "\" (" + err.what() + ")"; 1518b56d1ec6SPeter Hawkins throw nb::type_error(msg.c_str()); 1519436c6c9cSStella Laurenzo } 1520436c6c9cSStella Laurenzo try { 1521b56d1ec6SPeter Hawkins auto &attribute = nb::cast<PyAttribute &>(it.second); 1522436c6c9cSStella Laurenzo // TODO: Verify attribute originates from the same context. 1523436c6c9cSStella Laurenzo mlirAttributes.emplace_back(std::move(key), attribute); 1524b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 1525b56d1ec6SPeter Hawkins std::string msg = "Invalid attribute value for the key \"" + key + 1526b56d1ec6SPeter Hawkins "\" when attempting to create the operation \"" + 1527f4125e02SPeter Hawkins std::string(name) + "\" (" + err.what() + ")"; 1528b56d1ec6SPeter Hawkins throw nb::type_error(msg.c_str()); 1529b56d1ec6SPeter Hawkins } catch (std::runtime_error &) { 1530436c6c9cSStella Laurenzo // This exception seems thrown when the value is "None". 1531436c6c9cSStella Laurenzo std::string msg = 1532436c6c9cSStella Laurenzo "Found an invalid (`None`?) attribute value for the key \"" + key + 1533f4125e02SPeter Hawkins "\" when attempting to create the operation \"" + 1534f4125e02SPeter Hawkins std::string(name) + "\""; 1535b56d1ec6SPeter Hawkins throw std::runtime_error(msg); 1536436c6c9cSStella Laurenzo } 1537436c6c9cSStella Laurenzo } 1538436c6c9cSStella Laurenzo } 1539436c6c9cSStella Laurenzo // Unpack/validate successors. 1540436c6c9cSStella Laurenzo if (successors) { 1541436c6c9cSStella Laurenzo mlirSuccessors.reserve(successors->size()); 1542436c6c9cSStella Laurenzo for (auto *successor : *successors) { 1543436c6c9cSStella Laurenzo // TODO: Verify successor originate from the same context. 1544436c6c9cSStella Laurenzo if (!successor) 1545b56d1ec6SPeter Hawkins throw nb::value_error("successor block cannot be None"); 1546436c6c9cSStella Laurenzo mlirSuccessors.push_back(successor->get()); 1547436c6c9cSStella Laurenzo } 1548436c6c9cSStella Laurenzo } 1549436c6c9cSStella Laurenzo 1550436c6c9cSStella Laurenzo // Apply unpacked/validated to the operation state. Beyond this 1551436c6c9cSStella Laurenzo // point, exceptions cannot be thrown or else the state will leak. 1552436c6c9cSStella Laurenzo MlirOperationState state = 1553436c6c9cSStella Laurenzo mlirOperationStateGet(toMlirStringRef(name), location); 1554*acde3f72SPeter Hawkins if (!operands.empty()) 1555*acde3f72SPeter Hawkins mlirOperationStateAddOperands(&state, operands.size(), operands.data()); 1556f573bc24SJacques Pienaar state.enableResultTypeInference = inferType; 1557436c6c9cSStella Laurenzo if (!mlirResults.empty()) 1558436c6c9cSStella Laurenzo mlirOperationStateAddResults(&state, mlirResults.size(), 1559436c6c9cSStella Laurenzo mlirResults.data()); 1560436c6c9cSStella Laurenzo if (!mlirAttributes.empty()) { 1561436c6c9cSStella Laurenzo // Note that the attribute names directly reference bytes in 1562436c6c9cSStella Laurenzo // mlirAttributes, so that vector must not be changed from here 1563436c6c9cSStella Laurenzo // on. 1564436c6c9cSStella Laurenzo llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1565436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(mlirAttributes.size()); 1566436c6c9cSStella Laurenzo for (auto &it : mlirAttributes) 1567436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1568436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(it.second), 1569436c6c9cSStella Laurenzo toMlirStringRef(it.first)), 1570436c6c9cSStella Laurenzo it.second)); 1571436c6c9cSStella Laurenzo mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1572436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1573436c6c9cSStella Laurenzo } 1574436c6c9cSStella Laurenzo if (!mlirSuccessors.empty()) 1575436c6c9cSStella Laurenzo mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1576436c6c9cSStella Laurenzo mlirSuccessors.data()); 1577436c6c9cSStella Laurenzo if (regions) { 1578436c6c9cSStella Laurenzo llvm::SmallVector<MlirRegion, 4> mlirRegions; 1579436c6c9cSStella Laurenzo mlirRegions.resize(regions); 1580436c6c9cSStella Laurenzo for (int i = 0; i < regions; ++i) 1581436c6c9cSStella Laurenzo mlirRegions[i] = mlirRegionCreate(); 1582436c6c9cSStella Laurenzo mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1583436c6c9cSStella Laurenzo mlirRegions.data()); 1584436c6c9cSStella Laurenzo } 1585436c6c9cSStella Laurenzo 1586436c6c9cSStella Laurenzo // Construct the operation. 1587436c6c9cSStella Laurenzo MlirOperation operation = mlirOperationCreate(&state); 1588f573bc24SJacques Pienaar if (!operation.ptr) 1589b56d1ec6SPeter Hawkins throw nb::value_error("Operation creation failed"); 1590436c6c9cSStella Laurenzo PyOperationRef created = 1591436c6c9cSStella Laurenzo PyOperation::createDetached(location->getContext(), operation); 1592774818c0SDominik Grewe maybeInsertOperation(created, maybeIp); 1593436c6c9cSStella Laurenzo 159400a93e62SPeter Hawkins return created.getObject(); 1595436c6c9cSStella Laurenzo } 1596436c6c9cSStella Laurenzo 1597b56d1ec6SPeter Hawkins nb::object PyOperation::clone(const nb::object &maybeIp) { 1598774818c0SDominik Grewe MlirOperation clonedOperation = mlirOperationClone(operation); 1599774818c0SDominik Grewe PyOperationRef cloned = 1600774818c0SDominik Grewe PyOperation::createDetached(getContext(), clonedOperation); 1601774818c0SDominik Grewe maybeInsertOperation(cloned, maybeIp); 1602774818c0SDominik Grewe 1603774818c0SDominik Grewe return cloned->createOpView(); 1604774818c0SDominik Grewe } 1605774818c0SDominik Grewe 1606b56d1ec6SPeter Hawkins nb::object PyOperation::createOpView() { 160749745f87SMike Urbach checkValid(); 1608436c6c9cSStella Laurenzo MlirIdentifier ident = mlirOperationGetName(get()); 1609436c6c9cSStella Laurenzo MlirStringRef identStr = mlirIdentifierStr(ident); 1610a7f8b7cdSRahul Kayaith auto operationCls = PyGlobals::get().lookupOperationClass( 1611436c6c9cSStella Laurenzo StringRef(identStr.data, identStr.length)); 1612a7f8b7cdSRahul Kayaith if (operationCls) 1613b56d1ec6SPeter Hawkins return PyOpView::constructDerived(*operationCls, getRef().getObject()); 1614b56d1ec6SPeter Hawkins return nb::cast(PyOpView(getRef().getObject())); 1615436c6c9cSStella Laurenzo } 1616436c6c9cSStella Laurenzo 161749745f87SMike Urbach void PyOperation::erase() { 161849745f87SMike Urbach checkValid(); 161967897d77SOleksandr "Alex" Zinenko getContext()->clearOperationAndInside(*this); 162049745f87SMike Urbach mlirOperationDestroy(operation); 162149745f87SMike Urbach } 162249745f87SMike Urbach 1623*acde3f72SPeter Hawkins namespace { 1624*acde3f72SPeter Hawkins /// CRTP base class for Python MLIR values that subclass Value and should be 1625*acde3f72SPeter Hawkins /// castable from it. The value hierarchy is one level deep and is not supposed 1626*acde3f72SPeter Hawkins /// to accommodate other levels unless core MLIR changes. 1627*acde3f72SPeter Hawkins template <typename DerivedTy> 1628*acde3f72SPeter Hawkins class PyConcreteValue : public PyValue { 1629*acde3f72SPeter Hawkins public: 1630*acde3f72SPeter Hawkins // Derived classes must define statics for: 1631*acde3f72SPeter Hawkins // IsAFunctionTy isaFunction 1632*acde3f72SPeter Hawkins // const char *pyClassName 1633*acde3f72SPeter Hawkins // and redefine bindDerived. 1634*acde3f72SPeter Hawkins using ClassTy = nb::class_<DerivedTy, PyValue>; 1635*acde3f72SPeter Hawkins using IsAFunctionTy = bool (*)(MlirValue); 1636*acde3f72SPeter Hawkins 1637*acde3f72SPeter Hawkins PyConcreteValue() = default; 1638*acde3f72SPeter Hawkins PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1639*acde3f72SPeter Hawkins : PyValue(operationRef, value) {} 1640*acde3f72SPeter Hawkins PyConcreteValue(PyValue &orig) 1641*acde3f72SPeter Hawkins : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1642*acde3f72SPeter Hawkins 1643*acde3f72SPeter Hawkins /// Attempts to cast the original value to the derived type and throws on 1644*acde3f72SPeter Hawkins /// type mismatches. 1645*acde3f72SPeter Hawkins static MlirValue castFrom(PyValue &orig) { 1646*acde3f72SPeter Hawkins if (!DerivedTy::isaFunction(orig.get())) { 1647*acde3f72SPeter Hawkins auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig))); 1648*acde3f72SPeter Hawkins throw nb::value_error((Twine("Cannot cast value to ") + 1649*acde3f72SPeter Hawkins DerivedTy::pyClassName + " (from " + origRepr + 1650*acde3f72SPeter Hawkins ")") 1651*acde3f72SPeter Hawkins .str() 1652*acde3f72SPeter Hawkins .c_str()); 1653*acde3f72SPeter Hawkins } 1654*acde3f72SPeter Hawkins return orig.get(); 1655*acde3f72SPeter Hawkins } 1656*acde3f72SPeter Hawkins 1657*acde3f72SPeter Hawkins /// Binds the Python module objects to functions of this class. 1658*acde3f72SPeter Hawkins static void bind(nb::module_ &m) { 1659*acde3f72SPeter Hawkins auto cls = ClassTy(m, DerivedTy::pyClassName); 1660*acde3f72SPeter Hawkins cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")); 1661*acde3f72SPeter Hawkins cls.def_static( 1662*acde3f72SPeter Hawkins "isinstance", 1663*acde3f72SPeter Hawkins [](PyValue &otherValue) -> bool { 1664*acde3f72SPeter Hawkins return DerivedTy::isaFunction(otherValue); 1665*acde3f72SPeter Hawkins }, 1666*acde3f72SPeter Hawkins nb::arg("other_value")); 1667*acde3f72SPeter Hawkins cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, 1668*acde3f72SPeter Hawkins [](DerivedTy &self) { return self.maybeDownCast(); }); 1669*acde3f72SPeter Hawkins DerivedTy::bindDerived(cls); 1670*acde3f72SPeter Hawkins } 1671*acde3f72SPeter Hawkins 1672*acde3f72SPeter Hawkins /// Implemented by derived classes to add methods to the Python subclass. 1673*acde3f72SPeter Hawkins static void bindDerived(ClassTy &m) {} 1674*acde3f72SPeter Hawkins }; 1675*acde3f72SPeter Hawkins 1676*acde3f72SPeter Hawkins } // namespace 1677*acde3f72SPeter Hawkins 1678*acde3f72SPeter Hawkins /// Python wrapper for MlirOpResult. 1679*acde3f72SPeter Hawkins class PyOpResult : public PyConcreteValue<PyOpResult> { 1680*acde3f72SPeter Hawkins public: 1681*acde3f72SPeter Hawkins static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1682*acde3f72SPeter Hawkins static constexpr const char *pyClassName = "OpResult"; 1683*acde3f72SPeter Hawkins using PyConcreteValue::PyConcreteValue; 1684*acde3f72SPeter Hawkins 1685*acde3f72SPeter Hawkins static void bindDerived(ClassTy &c) { 1686*acde3f72SPeter Hawkins c.def_prop_ro("owner", [](PyOpResult &self) { 1687*acde3f72SPeter Hawkins assert( 1688*acde3f72SPeter Hawkins mlirOperationEqual(self.getParentOperation()->get(), 1689*acde3f72SPeter Hawkins mlirOpResultGetOwner(self.get())) && 1690*acde3f72SPeter Hawkins "expected the owner of the value in Python to match that in the IR"); 1691*acde3f72SPeter Hawkins return self.getParentOperation().getObject(); 1692*acde3f72SPeter Hawkins }); 1693*acde3f72SPeter Hawkins c.def_prop_ro("result_number", [](PyOpResult &self) { 1694*acde3f72SPeter Hawkins return mlirOpResultGetResultNumber(self.get()); 1695*acde3f72SPeter Hawkins }); 1696*acde3f72SPeter Hawkins } 1697*acde3f72SPeter Hawkins }; 1698*acde3f72SPeter Hawkins 1699*acde3f72SPeter Hawkins /// Returns the list of types of the values held by container. 1700*acde3f72SPeter Hawkins template <typename Container> 1701*acde3f72SPeter Hawkins static std::vector<MlirType> getValueTypes(Container &container, 1702*acde3f72SPeter Hawkins PyMlirContextRef &context) { 1703*acde3f72SPeter Hawkins std::vector<MlirType> result; 1704*acde3f72SPeter Hawkins result.reserve(container.size()); 1705*acde3f72SPeter Hawkins for (int i = 0, e = container.size(); i < e; ++i) { 1706*acde3f72SPeter Hawkins result.push_back(mlirValueGetType(container.getElement(i).get())); 1707*acde3f72SPeter Hawkins } 1708*acde3f72SPeter Hawkins return result; 1709*acde3f72SPeter Hawkins } 1710*acde3f72SPeter Hawkins 1711*acde3f72SPeter Hawkins /// A list of operation results. Internally, these are stored as consecutive 1712*acde3f72SPeter Hawkins /// elements, random access is cheap. The (returned) result list is associated 1713*acde3f72SPeter Hawkins /// with the operation whose results these are, and thus extends the lifetime of 1714*acde3f72SPeter Hawkins /// this operation. 1715*acde3f72SPeter Hawkins class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1716*acde3f72SPeter Hawkins public: 1717*acde3f72SPeter Hawkins static constexpr const char *pyClassName = "OpResultList"; 1718*acde3f72SPeter Hawkins using SliceableT = Sliceable<PyOpResultList, PyOpResult>; 1719*acde3f72SPeter Hawkins 1720*acde3f72SPeter Hawkins PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1721*acde3f72SPeter Hawkins intptr_t length = -1, intptr_t step = 1) 1722*acde3f72SPeter Hawkins : Sliceable(startIndex, 1723*acde3f72SPeter Hawkins length == -1 ? mlirOperationGetNumResults(operation->get()) 1724*acde3f72SPeter Hawkins : length, 1725*acde3f72SPeter Hawkins step), 1726*acde3f72SPeter Hawkins operation(std::move(operation)) {} 1727*acde3f72SPeter Hawkins 1728*acde3f72SPeter Hawkins static void bindDerived(ClassTy &c) { 1729*acde3f72SPeter Hawkins c.def_prop_ro("types", [](PyOpResultList &self) { 1730*acde3f72SPeter Hawkins return getValueTypes(self, self.operation->getContext()); 1731*acde3f72SPeter Hawkins }); 1732*acde3f72SPeter Hawkins c.def_prop_ro("owner", [](PyOpResultList &self) { 1733*acde3f72SPeter Hawkins return self.operation->createOpView(); 1734*acde3f72SPeter Hawkins }); 1735*acde3f72SPeter Hawkins } 1736*acde3f72SPeter Hawkins 1737*acde3f72SPeter Hawkins PyOperationRef &getOperation() { return operation; } 1738*acde3f72SPeter Hawkins 1739*acde3f72SPeter Hawkins private: 1740*acde3f72SPeter Hawkins /// Give the parent CRTP class access to hook implementations below. 1741*acde3f72SPeter Hawkins friend class Sliceable<PyOpResultList, PyOpResult>; 1742*acde3f72SPeter Hawkins 1743*acde3f72SPeter Hawkins intptr_t getRawNumElements() { 1744*acde3f72SPeter Hawkins operation->checkValid(); 1745*acde3f72SPeter Hawkins return mlirOperationGetNumResults(operation->get()); 1746*acde3f72SPeter Hawkins } 1747*acde3f72SPeter Hawkins 1748*acde3f72SPeter Hawkins PyOpResult getRawElement(intptr_t index) { 1749*acde3f72SPeter Hawkins PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1750*acde3f72SPeter Hawkins return PyOpResult(value); 1751*acde3f72SPeter Hawkins } 1752*acde3f72SPeter Hawkins 1753*acde3f72SPeter Hawkins PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1754*acde3f72SPeter Hawkins return PyOpResultList(operation, startIndex, length, step); 1755*acde3f72SPeter Hawkins } 1756*acde3f72SPeter Hawkins 1757*acde3f72SPeter Hawkins PyOperationRef operation; 1758*acde3f72SPeter Hawkins }; 1759*acde3f72SPeter Hawkins 1760436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1761436c6c9cSStella Laurenzo // PyOpView 1762436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1763436c6c9cSStella Laurenzo 1764b56d1ec6SPeter Hawkins static void populateResultTypes(StringRef name, nb::list resultTypeList, 1765b56d1ec6SPeter Hawkins const nb::object &resultSegmentSpecObj, 1766f573bc24SJacques Pienaar std::vector<int32_t> &resultSegmentLengths, 1767f573bc24SJacques Pienaar std::vector<PyType *> &resultTypes) { 1768436c6c9cSStella Laurenzo resultTypes.reserve(resultTypeList.size()); 1769436c6c9cSStella Laurenzo if (resultSegmentSpecObj.is_none()) { 1770436c6c9cSStella Laurenzo // Non-variadic result unpacking. 1771e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(resultTypeList)) { 1772436c6c9cSStella Laurenzo try { 1773b56d1ec6SPeter Hawkins resultTypes.push_back(nb::cast<PyType *>(it.value())); 1774436c6c9cSStella Laurenzo if (!resultTypes.back()) 1775b56d1ec6SPeter Hawkins throw nb::cast_error(); 1776b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 1777b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Result ") + 1778436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1779436c6c9cSStella Laurenzo name + "\" must be a Type (" + err.what() + ")") 1780b56d1ec6SPeter Hawkins .str() 1781b56d1ec6SPeter Hawkins .c_str()); 1782436c6c9cSStella Laurenzo } 1783436c6c9cSStella Laurenzo } 1784436c6c9cSStella Laurenzo } else { 1785436c6c9cSStella Laurenzo // Sized result unpacking. 1786b56d1ec6SPeter Hawkins auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj); 1787436c6c9cSStella Laurenzo if (resultSegmentSpec.size() != resultTypeList.size()) { 1788b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Operation \"") + name + 1789436c6c9cSStella Laurenzo "\" requires " + 1790436c6c9cSStella Laurenzo llvm::Twine(resultSegmentSpec.size()) + 1791436c6c9cSStella Laurenzo " result segments but was provided " + 1792436c6c9cSStella Laurenzo llvm::Twine(resultTypeList.size())) 1793b56d1ec6SPeter Hawkins .str() 1794b56d1ec6SPeter Hawkins .c_str()); 1795436c6c9cSStella Laurenzo } 1796436c6c9cSStella Laurenzo resultSegmentLengths.reserve(resultTypeList.size()); 1797e4853be2SMehdi Amini for (const auto &it : 1798436c6c9cSStella Laurenzo llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1799436c6c9cSStella Laurenzo int segmentSpec = std::get<1>(it.value()); 1800436c6c9cSStella Laurenzo if (segmentSpec == 1 || segmentSpec == 0) { 1801436c6c9cSStella Laurenzo // Unpack unary element. 1802436c6c9cSStella Laurenzo try { 1803b56d1ec6SPeter Hawkins auto *resultType = nb::cast<PyType *>(std::get<0>(it.value())); 1804436c6c9cSStella Laurenzo if (resultType) { 1805436c6c9cSStella Laurenzo resultTypes.push_back(resultType); 1806436c6c9cSStella Laurenzo resultSegmentLengths.push_back(1); 1807436c6c9cSStella Laurenzo } else if (segmentSpec == 0) { 1808436c6c9cSStella Laurenzo // Allowed to be optional. 1809436c6c9cSStella Laurenzo resultSegmentLengths.push_back(0); 1810436c6c9cSStella Laurenzo } else { 1811b56d1ec6SPeter Hawkins throw nb::value_error( 1812b56d1ec6SPeter Hawkins (llvm::Twine("Result ") + llvm::Twine(it.index()) + 1813b56d1ec6SPeter Hawkins " of operation \"" + name + 1814b56d1ec6SPeter Hawkins "\" must be a Type (was None and result is not optional)") 1815b56d1ec6SPeter Hawkins .str() 1816b56d1ec6SPeter Hawkins .c_str()); 1817436c6c9cSStella Laurenzo } 1818b56d1ec6SPeter Hawkins } catch (nb::cast_error &err) { 1819b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Result ") + 1820436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1821436c6c9cSStella Laurenzo name + "\" must be a Type (" + err.what() + 1822436c6c9cSStella Laurenzo ")") 1823b56d1ec6SPeter Hawkins .str() 1824b56d1ec6SPeter Hawkins .c_str()); 1825436c6c9cSStella Laurenzo } 1826436c6c9cSStella Laurenzo } else if (segmentSpec == -1) { 1827436c6c9cSStella Laurenzo // Unpack sequence by appending. 1828436c6c9cSStella Laurenzo try { 1829436c6c9cSStella Laurenzo if (std::get<0>(it.value()).is_none()) { 1830436c6c9cSStella Laurenzo // Treat it as an empty list. 1831436c6c9cSStella Laurenzo resultSegmentLengths.push_back(0); 1832436c6c9cSStella Laurenzo } else { 1833436c6c9cSStella Laurenzo // Unpack the list. 1834b56d1ec6SPeter Hawkins auto segment = nb::cast<nb::sequence>(std::get<0>(it.value())); 1835b56d1ec6SPeter Hawkins for (nb::handle segmentItem : segment) { 1836b56d1ec6SPeter Hawkins resultTypes.push_back(nb::cast<PyType *>(segmentItem)); 1837436c6c9cSStella Laurenzo if (!resultTypes.back()) { 1838b56d1ec6SPeter Hawkins throw nb::type_error("contained a None item"); 1839436c6c9cSStella Laurenzo } 1840436c6c9cSStella Laurenzo } 1841b56d1ec6SPeter Hawkins resultSegmentLengths.push_back(nb::len(segment)); 1842436c6c9cSStella Laurenzo } 1843436c6c9cSStella Laurenzo } catch (std::exception &err) { 1844436c6c9cSStella Laurenzo // NOTE: Sloppy to be using a catch-all here, but there are at least 1845436c6c9cSStella Laurenzo // three different unrelated exceptions that can be thrown in the 1846436c6c9cSStella Laurenzo // above "casts". Just keep the scope above small and catch them all. 1847b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Result ") + 1848436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1849436c6c9cSStella Laurenzo name + "\" must be a Sequence of Types (" + 1850436c6c9cSStella Laurenzo err.what() + ")") 1851b56d1ec6SPeter Hawkins .str() 1852b56d1ec6SPeter Hawkins .c_str()); 1853436c6c9cSStella Laurenzo } 1854436c6c9cSStella Laurenzo } else { 1855b56d1ec6SPeter Hawkins throw nb::value_error("Unexpected segment spec"); 1856436c6c9cSStella Laurenzo } 1857436c6c9cSStella Laurenzo } 1858436c6c9cSStella Laurenzo } 1859f573bc24SJacques Pienaar } 1860f573bc24SJacques Pienaar 1861*acde3f72SPeter Hawkins static MlirValue getUniqueResult(MlirOperation operation) { 1862*acde3f72SPeter Hawkins auto numResults = mlirOperationGetNumResults(operation); 1863*acde3f72SPeter Hawkins if (numResults != 1) { 1864*acde3f72SPeter Hawkins auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 1865*acde3f72SPeter Hawkins throw nb::value_error((Twine("Cannot call .result on operation ") + 1866*acde3f72SPeter Hawkins StringRef(name.data, name.length) + " which has " + 1867*acde3f72SPeter Hawkins Twine(numResults) + 1868*acde3f72SPeter Hawkins " results (it is only valid for operations with a " 1869*acde3f72SPeter Hawkins "single result)") 1870*acde3f72SPeter Hawkins .str() 1871*acde3f72SPeter Hawkins .c_str()); 1872*acde3f72SPeter Hawkins } 1873*acde3f72SPeter Hawkins return mlirOperationGetResult(operation, 0); 1874*acde3f72SPeter Hawkins } 1875*acde3f72SPeter Hawkins 1876*acde3f72SPeter Hawkins static MlirValue getOpResultOrValue(nb::handle operand) { 1877*acde3f72SPeter Hawkins if (operand.is_none()) { 1878*acde3f72SPeter Hawkins throw nb::value_error("contained a None item"); 1879*acde3f72SPeter Hawkins } 1880*acde3f72SPeter Hawkins PyOperationBase *op; 1881*acde3f72SPeter Hawkins if (nb::try_cast<PyOperationBase *>(operand, op)) { 1882*acde3f72SPeter Hawkins return getUniqueResult(op->getOperation()); 1883*acde3f72SPeter Hawkins } 1884*acde3f72SPeter Hawkins PyOpResultList *opResultList; 1885*acde3f72SPeter Hawkins if (nb::try_cast<PyOpResultList *>(operand, opResultList)) { 1886*acde3f72SPeter Hawkins return getUniqueResult(opResultList->getOperation()->get()); 1887*acde3f72SPeter Hawkins } 1888*acde3f72SPeter Hawkins PyValue *value; 1889*acde3f72SPeter Hawkins if (nb::try_cast<PyValue *>(operand, value)) { 1890*acde3f72SPeter Hawkins return value->get(); 1891*acde3f72SPeter Hawkins } 1892*acde3f72SPeter Hawkins throw nb::value_error("is not a Value"); 1893*acde3f72SPeter Hawkins } 1894*acde3f72SPeter Hawkins 1895b56d1ec6SPeter Hawkins nb::object PyOpView::buildGeneric( 1896f4125e02SPeter Hawkins std::string_view name, std::tuple<int, bool> opRegionSpec, 1897f4125e02SPeter Hawkins nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj, 1898f4125e02SPeter Hawkins std::optional<nb::list> resultTypeList, nb::list operandList, 1899f4125e02SPeter Hawkins std::optional<nb::dict> attributes, 1900f573bc24SJacques Pienaar std::optional<std::vector<PyBlock *>> successors, 1901f573bc24SJacques Pienaar std::optional<int> regions, DefaultingPyLocation location, 1902b56d1ec6SPeter Hawkins const nb::object &maybeIp) { 1903f573bc24SJacques Pienaar PyMlirContextRef context = location->getContext(); 1904f4125e02SPeter Hawkins 1905f573bc24SJacques Pienaar // Class level operation construction metadata. 1906f573bc24SJacques Pienaar // Operand and result segment specs are either none, which does no 1907f573bc24SJacques Pienaar // variadic unpacking, or a list of ints with segment sizes, where each 1908f573bc24SJacques Pienaar // element is either a positive number (typically 1 for a scalar) or -1 to 1909f573bc24SJacques Pienaar // indicate that it is derived from the length of the same-indexed operand 1910f573bc24SJacques Pienaar // or result (implying that it is a list at that position). 1911f573bc24SJacques Pienaar std::vector<int32_t> operandSegmentLengths; 1912f573bc24SJacques Pienaar std::vector<int32_t> resultSegmentLengths; 1913f573bc24SJacques Pienaar 1914f573bc24SJacques Pienaar // Validate/determine region count. 1915f573bc24SJacques Pienaar int opMinRegionCount = std::get<0>(opRegionSpec); 1916f573bc24SJacques Pienaar bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1917f573bc24SJacques Pienaar if (!regions) { 1918f573bc24SJacques Pienaar regions = opMinRegionCount; 1919f573bc24SJacques Pienaar } 1920f573bc24SJacques Pienaar if (*regions < opMinRegionCount) { 1921b56d1ec6SPeter Hawkins throw nb::value_error( 1922f573bc24SJacques Pienaar (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1923f573bc24SJacques Pienaar llvm::Twine(opMinRegionCount) + 1924f573bc24SJacques Pienaar " regions but was built with regions=" + llvm::Twine(*regions)) 1925b56d1ec6SPeter Hawkins .str() 1926b56d1ec6SPeter Hawkins .c_str()); 1927f573bc24SJacques Pienaar } 1928f573bc24SJacques Pienaar if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1929b56d1ec6SPeter Hawkins throw nb::value_error( 1930f573bc24SJacques Pienaar (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1931f573bc24SJacques Pienaar llvm::Twine(opMinRegionCount) + 1932f573bc24SJacques Pienaar " regions but was built with regions=" + llvm::Twine(*regions)) 1933b56d1ec6SPeter Hawkins .str() 1934b56d1ec6SPeter Hawkins .c_str()); 1935f573bc24SJacques Pienaar } 1936f573bc24SJacques Pienaar 1937f573bc24SJacques Pienaar // Unpack results. 1938f573bc24SJacques Pienaar std::vector<PyType *> resultTypes; 1939f573bc24SJacques Pienaar if (resultTypeList.has_value()) { 1940f573bc24SJacques Pienaar populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, 1941f573bc24SJacques Pienaar resultSegmentLengths, resultTypes); 1942f573bc24SJacques Pienaar } 1943436c6c9cSStella Laurenzo 1944436c6c9cSStella Laurenzo // Unpack operands. 1945*acde3f72SPeter Hawkins llvm::SmallVector<MlirValue, 4> operands; 1946436c6c9cSStella Laurenzo operands.reserve(operands.size()); 1947436c6c9cSStella Laurenzo if (operandSegmentSpecObj.is_none()) { 1948436c6c9cSStella Laurenzo // Non-sized operand unpacking. 1949e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(operandList)) { 1950436c6c9cSStella Laurenzo try { 1951*acde3f72SPeter Hawkins operands.push_back(getOpResultOrValue(it.value())); 1952*acde3f72SPeter Hawkins } catch (nb::builtin_exception &err) { 1953b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Operand ") + 1954436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1955436c6c9cSStella Laurenzo name + "\" must be a Value (" + err.what() + ")") 1956b56d1ec6SPeter Hawkins .str() 1957b56d1ec6SPeter Hawkins .c_str()); 1958436c6c9cSStella Laurenzo } 1959436c6c9cSStella Laurenzo } 1960436c6c9cSStella Laurenzo } else { 1961436c6c9cSStella Laurenzo // Sized operand unpacking. 1962b56d1ec6SPeter Hawkins auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj); 1963436c6c9cSStella Laurenzo if (operandSegmentSpec.size() != operandList.size()) { 1964b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Operation \"") + name + 1965436c6c9cSStella Laurenzo "\" requires " + 1966436c6c9cSStella Laurenzo llvm::Twine(operandSegmentSpec.size()) + 1967436c6c9cSStella Laurenzo "operand segments but was provided " + 1968436c6c9cSStella Laurenzo llvm::Twine(operandList.size())) 1969b56d1ec6SPeter Hawkins .str() 1970b56d1ec6SPeter Hawkins .c_str()); 1971436c6c9cSStella Laurenzo } 1972436c6c9cSStella Laurenzo operandSegmentLengths.reserve(operandList.size()); 1973e4853be2SMehdi Amini for (const auto &it : 1974436c6c9cSStella Laurenzo llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1975436c6c9cSStella Laurenzo int segmentSpec = std::get<1>(it.value()); 1976436c6c9cSStella Laurenzo if (segmentSpec == 1 || segmentSpec == 0) { 1977436c6c9cSStella Laurenzo // Unpack unary element. 1978*acde3f72SPeter Hawkins auto &operand = std::get<0>(it.value()); 1979*acde3f72SPeter Hawkins if (!operand.is_none()) { 1980436c6c9cSStella Laurenzo try { 1981*acde3f72SPeter Hawkins 1982*acde3f72SPeter Hawkins operands.push_back(getOpResultOrValue(operand)); 1983*acde3f72SPeter Hawkins } catch (nb::builtin_exception &err) { 1984*acde3f72SPeter Hawkins throw nb::value_error((llvm::Twine("Operand ") + 1985*acde3f72SPeter Hawkins llvm::Twine(it.index()) + 1986*acde3f72SPeter Hawkins " of operation \"" + name + 1987*acde3f72SPeter Hawkins "\" must be a Value (" + err.what() + ")") 1988*acde3f72SPeter Hawkins .str() 1989*acde3f72SPeter Hawkins .c_str()); 1990*acde3f72SPeter Hawkins } 1991*acde3f72SPeter Hawkins 1992436c6c9cSStella Laurenzo operandSegmentLengths.push_back(1); 1993436c6c9cSStella Laurenzo } else if (segmentSpec == 0) { 1994436c6c9cSStella Laurenzo // Allowed to be optional. 1995436c6c9cSStella Laurenzo operandSegmentLengths.push_back(0); 1996436c6c9cSStella Laurenzo } else { 1997b56d1ec6SPeter Hawkins throw nb::value_error( 1998b56d1ec6SPeter Hawkins (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 1999b56d1ec6SPeter Hawkins " of operation \"" + name + 2000b56d1ec6SPeter Hawkins "\" must be a Value (was None and operand is not optional)") 2001b56d1ec6SPeter Hawkins .str() 2002b56d1ec6SPeter Hawkins .c_str()); 2003436c6c9cSStella Laurenzo } 2004436c6c9cSStella Laurenzo } else if (segmentSpec == -1) { 2005436c6c9cSStella Laurenzo // Unpack sequence by appending. 2006436c6c9cSStella Laurenzo try { 2007436c6c9cSStella Laurenzo if (std::get<0>(it.value()).is_none()) { 2008436c6c9cSStella Laurenzo // Treat it as an empty list. 2009436c6c9cSStella Laurenzo operandSegmentLengths.push_back(0); 2010436c6c9cSStella Laurenzo } else { 2011436c6c9cSStella Laurenzo // Unpack the list. 2012b56d1ec6SPeter Hawkins auto segment = nb::cast<nb::sequence>(std::get<0>(it.value())); 2013b56d1ec6SPeter Hawkins for (nb::handle segmentItem : segment) { 2014*acde3f72SPeter Hawkins operands.push_back(getOpResultOrValue(segmentItem)); 2015436c6c9cSStella Laurenzo } 2016b56d1ec6SPeter Hawkins operandSegmentLengths.push_back(nb::len(segment)); 2017436c6c9cSStella Laurenzo } 2018436c6c9cSStella Laurenzo } catch (std::exception &err) { 2019436c6c9cSStella Laurenzo // NOTE: Sloppy to be using a catch-all here, but there are at least 2020436c6c9cSStella Laurenzo // three different unrelated exceptions that can be thrown in the 2021436c6c9cSStella Laurenzo // above "casts". Just keep the scope above small and catch them all. 2022b56d1ec6SPeter Hawkins throw nb::value_error((llvm::Twine("Operand ") + 2023436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 2024436c6c9cSStella Laurenzo name + "\" must be a Sequence of Values (" + 2025436c6c9cSStella Laurenzo err.what() + ")") 2026b56d1ec6SPeter Hawkins .str() 2027b56d1ec6SPeter Hawkins .c_str()); 2028436c6c9cSStella Laurenzo } 2029436c6c9cSStella Laurenzo } else { 2030b56d1ec6SPeter Hawkins throw nb::value_error("Unexpected segment spec"); 2031436c6c9cSStella Laurenzo } 2032436c6c9cSStella Laurenzo } 2033436c6c9cSStella Laurenzo } 2034436c6c9cSStella Laurenzo 2035436c6c9cSStella Laurenzo // Merge operand/result segment lengths into attributes if needed. 2036436c6c9cSStella Laurenzo if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 2037436c6c9cSStella Laurenzo // Dup. 2038436c6c9cSStella Laurenzo if (attributes) { 2039b56d1ec6SPeter Hawkins attributes = nb::dict(*attributes); 2040436c6c9cSStella Laurenzo } else { 2041b56d1ec6SPeter Hawkins attributes = nb::dict(); 2042436c6c9cSStella Laurenzo } 2043363b6559SMehdi Amini if (attributes->contains("resultSegmentSizes") || 2044363b6559SMehdi Amini attributes->contains("operandSegmentSizes")) { 2045b56d1ec6SPeter Hawkins throw nb::value_error("Manually setting a 'resultSegmentSizes' or " 2046363b6559SMehdi Amini "'operandSegmentSizes' attribute is unsupported. " 2047436c6c9cSStella Laurenzo "Use Operation.create for such low-level access."); 2048436c6c9cSStella Laurenzo } 2049436c6c9cSStella Laurenzo 2050363b6559SMehdi Amini // Add resultSegmentSizes attribute. 2051436c6c9cSStella Laurenzo if (!resultSegmentLengths.empty()) { 205258a47508SJeff Niu MlirAttribute segmentLengthAttr = 205358a47508SJeff Niu mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), 205458a47508SJeff Niu resultSegmentLengths.data()); 2055363b6559SMehdi Amini (*attributes)["resultSegmentSizes"] = 2056436c6c9cSStella Laurenzo PyAttribute(context, segmentLengthAttr); 2057436c6c9cSStella Laurenzo } 2058436c6c9cSStella Laurenzo 2059363b6559SMehdi Amini // Add operandSegmentSizes attribute. 2060436c6c9cSStella Laurenzo if (!operandSegmentLengths.empty()) { 206158a47508SJeff Niu MlirAttribute segmentLengthAttr = 206258a47508SJeff Niu mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), 206358a47508SJeff Niu operandSegmentLengths.data()); 2064363b6559SMehdi Amini (*attributes)["operandSegmentSizes"] = 2065436c6c9cSStella Laurenzo PyAttribute(context, segmentLengthAttr); 2066436c6c9cSStella Laurenzo } 2067436c6c9cSStella Laurenzo } 2068436c6c9cSStella Laurenzo 2069436c6c9cSStella Laurenzo // Delegate to create. 2070337c937dSMehdi Amini return PyOperation::create(name, 2071436c6c9cSStella Laurenzo /*results=*/std::move(resultTypes), 2072436c6c9cSStella Laurenzo /*operands=*/std::move(operands), 2073436c6c9cSStella Laurenzo /*attributes=*/std::move(attributes), 2074436c6c9cSStella Laurenzo /*successors=*/std::move(successors), 2075f573bc24SJacques Pienaar /*regions=*/*regions, location, maybeIp, 2076f573bc24SJacques Pienaar !resultTypeList); 2077436c6c9cSStella Laurenzo } 2078436c6c9cSStella Laurenzo 2079b56d1ec6SPeter Hawkins nb::object PyOpView::constructDerived(const nb::object &cls, 2080b56d1ec6SPeter Hawkins const nb::object &operation) { 2081b56d1ec6SPeter Hawkins nb::handle opViewType = nb::type<PyOpView>(); 2082b56d1ec6SPeter Hawkins nb::object instance = cls.attr("__new__")(cls); 2083a7f8b7cdSRahul Kayaith opViewType.attr("__init__")(instance, operation); 2084a7f8b7cdSRahul Kayaith return instance; 2085a7f8b7cdSRahul Kayaith } 2086a7f8b7cdSRahul Kayaith 2087b56d1ec6SPeter Hawkins PyOpView::PyOpView(const nb::object &operationObject) 2088436c6c9cSStella Laurenzo // Casting through the PyOperationBase base-class and then back to the 2089436c6c9cSStella Laurenzo // Operation lets us accept any PyOperationBase subclass. 2090b56d1ec6SPeter Hawkins : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()), 2091436c6c9cSStella Laurenzo operationObject(operation.getRef().getObject()) {} 2092436c6c9cSStella Laurenzo 2093436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2094436c6c9cSStella Laurenzo // PyInsertionPoint. 2095436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2096436c6c9cSStella Laurenzo 2097436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 2098436c6c9cSStella Laurenzo 2099436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 2100436c6c9cSStella Laurenzo : refOperation(beforeOperationBase.getOperation().getRef()), 2101436c6c9cSStella Laurenzo block((*refOperation)->getBlock()) {} 2102436c6c9cSStella Laurenzo 2103436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) { 2104436c6c9cSStella Laurenzo PyOperation &operation = operationBase.getOperation(); 2105436c6c9cSStella Laurenzo if (operation.isAttached()) 2106b56d1ec6SPeter Hawkins throw nb::value_error( 2107436c6c9cSStella Laurenzo "Attempt to insert operation that is already attached"); 2108436c6c9cSStella Laurenzo block.getParentOperation()->checkValid(); 2109436c6c9cSStella Laurenzo MlirOperation beforeOp = {nullptr}; 2110436c6c9cSStella Laurenzo if (refOperation) { 2111436c6c9cSStella Laurenzo // Insert before operation. 2112436c6c9cSStella Laurenzo (*refOperation)->checkValid(); 2113436c6c9cSStella Laurenzo beforeOp = (*refOperation)->get(); 2114436c6c9cSStella Laurenzo } else { 2115436c6c9cSStella Laurenzo // Insert at end (before null) is only valid if the block does not 2116436c6c9cSStella Laurenzo // already end in a known terminator (violating this will cause assertion 2117436c6c9cSStella Laurenzo // failures later). 2118436c6c9cSStella Laurenzo if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 2119b56d1ec6SPeter Hawkins throw nb::index_error("Cannot insert operation at the end of a block " 2120436c6c9cSStella Laurenzo "that already has a terminator. Did you mean to " 2121436c6c9cSStella Laurenzo "use 'InsertionPoint.at_block_terminator(block)' " 2122436c6c9cSStella Laurenzo "versus 'InsertionPoint(block)'?"); 2123436c6c9cSStella Laurenzo } 2124436c6c9cSStella Laurenzo } 2125436c6c9cSStella Laurenzo mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 2126436c6c9cSStella Laurenzo operation.setAttached(); 2127436c6c9cSStella Laurenzo } 2128436c6c9cSStella Laurenzo 2129436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 2130436c6c9cSStella Laurenzo MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 2131436c6c9cSStella Laurenzo if (mlirOperationIsNull(firstOp)) { 2132436c6c9cSStella Laurenzo // Just insert at end. 2133436c6c9cSStella Laurenzo return PyInsertionPoint(block); 2134436c6c9cSStella Laurenzo } 2135436c6c9cSStella Laurenzo 2136436c6c9cSStella Laurenzo // Insert before first op. 2137436c6c9cSStella Laurenzo PyOperationRef firstOpRef = PyOperation::forOperation( 2138436c6c9cSStella Laurenzo block.getParentOperation()->getContext(), firstOp); 2139436c6c9cSStella Laurenzo return PyInsertionPoint{block, std::move(firstOpRef)}; 2140436c6c9cSStella Laurenzo } 2141436c6c9cSStella Laurenzo 2142436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 2143436c6c9cSStella Laurenzo MlirOperation terminator = mlirBlockGetTerminator(block.get()); 2144436c6c9cSStella Laurenzo if (mlirOperationIsNull(terminator)) 2145b56d1ec6SPeter Hawkins throw nb::value_error("Block has no terminator"); 2146436c6c9cSStella Laurenzo PyOperationRef terminatorOpRef = PyOperation::forOperation( 2147436c6c9cSStella Laurenzo block.getParentOperation()->getContext(), terminator); 2148436c6c9cSStella Laurenzo return PyInsertionPoint{block, std::move(terminatorOpRef)}; 2149436c6c9cSStella Laurenzo } 2150436c6c9cSStella Laurenzo 2151b56d1ec6SPeter Hawkins nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { 2152b56d1ec6SPeter Hawkins return PyThreadContextEntry::pushInsertionPoint(insertPoint); 2153436c6c9cSStella Laurenzo } 2154436c6c9cSStella Laurenzo 2155b56d1ec6SPeter Hawkins void PyInsertionPoint::contextExit(const nb::object &excType, 2156b56d1ec6SPeter Hawkins const nb::object &excVal, 2157b56d1ec6SPeter Hawkins const nb::object &excTb) { 2158436c6c9cSStella Laurenzo PyThreadContextEntry::popInsertionPoint(*this); 2159436c6c9cSStella Laurenzo } 2160436c6c9cSStella Laurenzo 2161436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2162436c6c9cSStella Laurenzo // PyAttribute. 2163436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2164436c6c9cSStella Laurenzo 2165e6d738e0SRahul Kayaith bool PyAttribute::operator==(const PyAttribute &other) const { 2166436c6c9cSStella Laurenzo return mlirAttributeEqual(attr, other.attr); 2167436c6c9cSStella Laurenzo } 2168436c6c9cSStella Laurenzo 2169b56d1ec6SPeter Hawkins nb::object PyAttribute::getCapsule() { 2170b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this)); 2171436c6c9cSStella Laurenzo } 2172436c6c9cSStella Laurenzo 2173b56d1ec6SPeter Hawkins PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { 2174436c6c9cSStella Laurenzo MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 2175436c6c9cSStella Laurenzo if (mlirAttributeIsNull(rawAttr)) 2176b56d1ec6SPeter Hawkins throw nb::python_error(); 2177436c6c9cSStella Laurenzo return PyAttribute( 2178436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 2179436c6c9cSStella Laurenzo } 2180436c6c9cSStella Laurenzo 2181436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2182436c6c9cSStella Laurenzo // PyNamedAttribute. 2183436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2184436c6c9cSStella Laurenzo 2185436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 2186436c6c9cSStella Laurenzo : ownedName(new std::string(std::move(ownedName))) { 2187436c6c9cSStella Laurenzo namedAttr = mlirNamedAttributeGet( 2188436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(attr), 2189436c6c9cSStella Laurenzo toMlirStringRef(*this->ownedName)), 2190436c6c9cSStella Laurenzo attr); 2191436c6c9cSStella Laurenzo } 2192436c6c9cSStella Laurenzo 2193436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2194436c6c9cSStella Laurenzo // PyType. 2195436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2196436c6c9cSStella Laurenzo 2197e6d738e0SRahul Kayaith bool PyType::operator==(const PyType &other) const { 2198436c6c9cSStella Laurenzo return mlirTypeEqual(type, other.type); 2199436c6c9cSStella Laurenzo } 2200436c6c9cSStella Laurenzo 2201b56d1ec6SPeter Hawkins nb::object PyType::getCapsule() { 2202b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this)); 2203436c6c9cSStella Laurenzo } 2204436c6c9cSStella Laurenzo 2205b56d1ec6SPeter Hawkins PyType PyType::createFromCapsule(nb::object capsule) { 2206436c6c9cSStella Laurenzo MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 2207436c6c9cSStella Laurenzo if (mlirTypeIsNull(rawType)) 2208b56d1ec6SPeter Hawkins throw nb::python_error(); 2209436c6c9cSStella Laurenzo return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 2210436c6c9cSStella Laurenzo rawType); 2211436c6c9cSStella Laurenzo } 2212436c6c9cSStella Laurenzo 2213436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2214d39a7844Smax // PyTypeID. 2215d39a7844Smax //------------------------------------------------------------------------------ 2216d39a7844Smax 2217b56d1ec6SPeter Hawkins nb::object PyTypeID::getCapsule() { 2218b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this)); 2219d39a7844Smax } 2220d39a7844Smax 2221b56d1ec6SPeter Hawkins PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { 2222d39a7844Smax MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); 2223d39a7844Smax if (mlirTypeIDIsNull(mlirTypeID)) 2224b56d1ec6SPeter Hawkins throw nb::python_error(); 2225d39a7844Smax return PyTypeID(mlirTypeID); 2226d39a7844Smax } 2227d39a7844Smax bool PyTypeID::operator==(const PyTypeID &other) const { 2228d39a7844Smax return mlirTypeIDEqual(typeID, other.typeID); 2229d39a7844Smax } 2230d39a7844Smax 2231d39a7844Smax //------------------------------------------------------------------------------ 22327c850867SMaksim Levental // PyValue and subclasses. 2233436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2234436c6c9cSStella Laurenzo 2235b56d1ec6SPeter Hawkins nb::object PyValue::getCapsule() { 2236b56d1ec6SPeter Hawkins return nb::steal<nb::object>(mlirPythonValueToCapsule(get())); 22373f3d1c90SMike Urbach } 22383f3d1c90SMike Urbach 2239b56d1ec6SPeter Hawkins nb::object PyValue::maybeDownCast() { 22407c850867SMaksim Levental MlirType type = mlirValueGetType(get()); 22417c850867SMaksim Levental MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); 22427c850867SMaksim Levental assert(!mlirTypeIDIsNull(mlirTypeID) && 22437c850867SMaksim Levental "mlirTypeID was expected to be non-null."); 2244b56d1ec6SPeter Hawkins std::optional<nb::callable> valueCaster = 22457c850867SMaksim Levental PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); 2246b56d1ec6SPeter Hawkins // nb::rv_policy::move means use std::move to move the return value 22477c850867SMaksim Levental // contents into a new instance that will be owned by Python. 2248b56d1ec6SPeter Hawkins nb::object thisObj = nb::cast(this, nb::rv_policy::move); 22497c850867SMaksim Levental if (!valueCaster) 22507c850867SMaksim Levental return thisObj; 22517c850867SMaksim Levental return valueCaster.value()(thisObj); 22527c850867SMaksim Levental } 22537c850867SMaksim Levental 2254b56d1ec6SPeter Hawkins PyValue PyValue::createFromCapsule(nb::object capsule) { 22553f3d1c90SMike Urbach MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 22563f3d1c90SMike Urbach if (mlirValueIsNull(value)) 2257b56d1ec6SPeter Hawkins throw nb::python_error(); 22583f3d1c90SMike Urbach MlirOperation owner; 22593f3d1c90SMike Urbach if (mlirValueIsAOpResult(value)) 22603f3d1c90SMike Urbach owner = mlirOpResultGetOwner(value); 22613f3d1c90SMike Urbach if (mlirValueIsABlockArgument(value)) 22623f3d1c90SMike Urbach owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 22633f3d1c90SMike Urbach if (mlirOperationIsNull(owner)) 2264b56d1ec6SPeter Hawkins throw nb::python_error(); 22653f3d1c90SMike Urbach MlirContext ctx = mlirOperationGetContext(owner); 22663f3d1c90SMike Urbach PyOperationRef ownerRef = 22673f3d1c90SMike Urbach PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 22683f3d1c90SMike Urbach return PyValue(ownerRef, value); 22693f3d1c90SMike Urbach } 22703f3d1c90SMike Urbach 227130d61893SAlex Zinenko //------------------------------------------------------------------------------ 227230d61893SAlex Zinenko // PySymbolTable. 227330d61893SAlex Zinenko //------------------------------------------------------------------------------ 227430d61893SAlex Zinenko 227530d61893SAlex Zinenko PySymbolTable::PySymbolTable(PyOperationBase &operation) 227630d61893SAlex Zinenko : operation(operation.getOperation().getRef()) { 227730d61893SAlex Zinenko symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 227830d61893SAlex Zinenko if (mlirSymbolTableIsNull(symbolTable)) { 2279b56d1ec6SPeter Hawkins throw nb::type_error("Operation is not a Symbol Table."); 228030d61893SAlex Zinenko } 228130d61893SAlex Zinenko } 228230d61893SAlex Zinenko 2283b56d1ec6SPeter Hawkins nb::object PySymbolTable::dunderGetItem(const std::string &name) { 228430d61893SAlex Zinenko operation->checkValid(); 228530d61893SAlex Zinenko MlirOperation symbol = mlirSymbolTableLookup( 228630d61893SAlex Zinenko symbolTable, mlirStringRefCreate(name.data(), name.length())); 228730d61893SAlex Zinenko if (mlirOperationIsNull(symbol)) 2288b56d1ec6SPeter Hawkins throw nb::key_error( 2289b56d1ec6SPeter Hawkins ("Symbol '" + name + "' not in the symbol table.").c_str()); 229030d61893SAlex Zinenko 229130d61893SAlex Zinenko return PyOperation::forOperation(operation->getContext(), symbol, 229230d61893SAlex Zinenko operation.getObject()) 229330d61893SAlex Zinenko ->createOpView(); 229430d61893SAlex Zinenko } 229530d61893SAlex Zinenko 229630d61893SAlex Zinenko void PySymbolTable::erase(PyOperationBase &symbol) { 229730d61893SAlex Zinenko operation->checkValid(); 229830d61893SAlex Zinenko symbol.getOperation().checkValid(); 229930d61893SAlex Zinenko mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 230030d61893SAlex Zinenko // The operation is also erased, so we must invalidate it. There may be Python 230130d61893SAlex Zinenko // references to this operation so we don't want to delete it from the list of 230230d61893SAlex Zinenko // live operations here. 230330d61893SAlex Zinenko symbol.getOperation().valid = false; 230430d61893SAlex Zinenko } 230530d61893SAlex Zinenko 230630d61893SAlex Zinenko void PySymbolTable::dunderDel(const std::string &name) { 2307b56d1ec6SPeter Hawkins nb::object operation = dunderGetItem(name); 2308b56d1ec6SPeter Hawkins erase(nb::cast<PyOperationBase &>(operation)); 230930d61893SAlex Zinenko } 231030d61893SAlex Zinenko 2311974c1596SRahul Kayaith MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { 231230d61893SAlex Zinenko operation->checkValid(); 231330d61893SAlex Zinenko symbol.getOperation().checkValid(); 231430d61893SAlex Zinenko MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 231530d61893SAlex Zinenko symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 231630d61893SAlex Zinenko if (mlirAttributeIsNull(symbolAttr)) 2317b56d1ec6SPeter Hawkins throw nb::value_error("Expected operation to have a symbol name."); 2318974c1596SRahul Kayaith return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); 231930d61893SAlex Zinenko } 232030d61893SAlex Zinenko 2321974c1596SRahul Kayaith MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 2322bdc31837SStella Laurenzo // Op must already be a symbol. 2323bdc31837SStella Laurenzo PyOperation &operation = symbol.getOperation(); 2324bdc31837SStella Laurenzo operation.checkValid(); 2325bdc31837SStella Laurenzo MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 2326bdc31837SStella Laurenzo MlirAttribute existingNameAttr = 2327bdc31837SStella Laurenzo mlirOperationGetAttributeByName(operation.get(), attrName); 2328bdc31837SStella Laurenzo if (mlirAttributeIsNull(existingNameAttr)) 2329b56d1ec6SPeter Hawkins throw nb::value_error("Expected operation to have a symbol name."); 2330974c1596SRahul Kayaith return existingNameAttr; 2331bdc31837SStella Laurenzo } 2332bdc31837SStella Laurenzo 2333bdc31837SStella Laurenzo void PySymbolTable::setSymbolName(PyOperationBase &symbol, 2334bdc31837SStella Laurenzo const std::string &name) { 2335bdc31837SStella Laurenzo // Op must already be a symbol. 2336bdc31837SStella Laurenzo PyOperation &operation = symbol.getOperation(); 2337bdc31837SStella Laurenzo operation.checkValid(); 2338bdc31837SStella Laurenzo MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 2339bdc31837SStella Laurenzo MlirAttribute existingNameAttr = 2340bdc31837SStella Laurenzo mlirOperationGetAttributeByName(operation.get(), attrName); 2341bdc31837SStella Laurenzo if (mlirAttributeIsNull(existingNameAttr)) 2342b56d1ec6SPeter Hawkins throw nb::value_error("Expected operation to have a symbol name."); 2343bdc31837SStella Laurenzo MlirAttribute newNameAttr = 2344bdc31837SStella Laurenzo mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 2345bdc31837SStella Laurenzo mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 2346bdc31837SStella Laurenzo } 2347bdc31837SStella Laurenzo 2348974c1596SRahul Kayaith MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 2349bdc31837SStella Laurenzo PyOperation &operation = symbol.getOperation(); 2350bdc31837SStella Laurenzo operation.checkValid(); 2351bdc31837SStella Laurenzo MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 2352bdc31837SStella Laurenzo MlirAttribute existingVisAttr = 2353bdc31837SStella Laurenzo mlirOperationGetAttributeByName(operation.get(), attrName); 2354bdc31837SStella Laurenzo if (mlirAttributeIsNull(existingVisAttr)) 2355b56d1ec6SPeter Hawkins throw nb::value_error("Expected operation to have a symbol visibility."); 2356974c1596SRahul Kayaith return existingVisAttr; 2357bdc31837SStella Laurenzo } 2358bdc31837SStella Laurenzo 2359bdc31837SStella Laurenzo void PySymbolTable::setVisibility(PyOperationBase &symbol, 2360bdc31837SStella Laurenzo const std::string &visibility) { 2361bdc31837SStella Laurenzo if (visibility != "public" && visibility != "private" && 2362bdc31837SStella Laurenzo visibility != "nested") 2363b56d1ec6SPeter Hawkins throw nb::value_error( 2364bdc31837SStella Laurenzo "Expected visibility to be 'public', 'private' or 'nested'"); 2365bdc31837SStella Laurenzo PyOperation &operation = symbol.getOperation(); 2366bdc31837SStella Laurenzo operation.checkValid(); 2367bdc31837SStella Laurenzo MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 2368bdc31837SStella Laurenzo MlirAttribute existingVisAttr = 2369bdc31837SStella Laurenzo mlirOperationGetAttributeByName(operation.get(), attrName); 2370bdc31837SStella Laurenzo if (mlirAttributeIsNull(existingVisAttr)) 2371b56d1ec6SPeter Hawkins throw nb::value_error("Expected operation to have a symbol visibility."); 2372bdc31837SStella Laurenzo MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 2373bdc31837SStella Laurenzo toMlirStringRef(visibility)); 2374bdc31837SStella Laurenzo mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 2375bdc31837SStella Laurenzo } 2376bdc31837SStella Laurenzo 2377bdc31837SStella Laurenzo void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 2378bdc31837SStella Laurenzo const std::string &newSymbol, 2379bdc31837SStella Laurenzo PyOperationBase &from) { 2380bdc31837SStella Laurenzo PyOperation &fromOperation = from.getOperation(); 2381bdc31837SStella Laurenzo fromOperation.checkValid(); 2382bdc31837SStella Laurenzo if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 2383bdc31837SStella Laurenzo toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 2384bdc31837SStella Laurenzo from.getOperation()))) 2385bdc31837SStella Laurenzo 2386b56d1ec6SPeter Hawkins throw nb::value_error("Symbol rename failed"); 2387bdc31837SStella Laurenzo } 2388bdc31837SStella Laurenzo 2389bdc31837SStella Laurenzo void PySymbolTable::walkSymbolTables(PyOperationBase &from, 2390bdc31837SStella Laurenzo bool allSymUsesVisible, 2391b56d1ec6SPeter Hawkins nb::object callback) { 2392bdc31837SStella Laurenzo PyOperation &fromOperation = from.getOperation(); 2393bdc31837SStella Laurenzo fromOperation.checkValid(); 2394bdc31837SStella Laurenzo struct UserData { 2395bdc31837SStella Laurenzo PyMlirContextRef context; 2396b56d1ec6SPeter Hawkins nb::object callback; 2397bdc31837SStella Laurenzo bool gotException; 2398bdc31837SStella Laurenzo std::string exceptionWhat; 2399b56d1ec6SPeter Hawkins nb::object exceptionType; 2400bdc31837SStella Laurenzo }; 2401bdc31837SStella Laurenzo UserData userData{ 2402bdc31837SStella Laurenzo fromOperation.getContext(), std::move(callback), false, {}, {}}; 2403bdc31837SStella Laurenzo mlirSymbolTableWalkSymbolTables( 2404bdc31837SStella Laurenzo fromOperation.get(), allSymUsesVisible, 2405bdc31837SStella Laurenzo [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 2406bdc31837SStella Laurenzo UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 2407bdc31837SStella Laurenzo auto pyFoundOp = 2408bdc31837SStella Laurenzo PyOperation::forOperation(calleeUserData->context, foundOp); 2409bdc31837SStella Laurenzo if (calleeUserData->gotException) 2410bdc31837SStella Laurenzo return; 2411bdc31837SStella Laurenzo try { 2412bdc31837SStella Laurenzo calleeUserData->callback(pyFoundOp.getObject(), isVisible); 2413b56d1ec6SPeter Hawkins } catch (nb::python_error &e) { 2414bdc31837SStella Laurenzo calleeUserData->gotException = true; 2415bdc31837SStella Laurenzo calleeUserData->exceptionWhat = e.what(); 2416b56d1ec6SPeter Hawkins calleeUserData->exceptionType = nb::borrow(e.type()); 2417bdc31837SStella Laurenzo } 2418bdc31837SStella Laurenzo }, 2419bdc31837SStella Laurenzo static_cast<void *>(&userData)); 2420bdc31837SStella Laurenzo if (userData.gotException) { 2421bdc31837SStella Laurenzo std::string message("Exception raised in callback: "); 2422bdc31837SStella Laurenzo message.append(userData.exceptionWhat); 2423337c937dSMehdi Amini throw std::runtime_error(message); 2424bdc31837SStella Laurenzo } 2425bdc31837SStella Laurenzo } 2426bdc31837SStella Laurenzo 2427436c6c9cSStella Laurenzo namespace { 2428436c6c9cSStella Laurenzo 2429436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument. 2430436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 2431436c6c9cSStella Laurenzo public: 2432436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 2433436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BlockArgument"; 2434436c6c9cSStella Laurenzo using PyConcreteValue::PyConcreteValue; 2435436c6c9cSStella Laurenzo 2436436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 2437b56d1ec6SPeter Hawkins c.def_prop_ro("owner", [](PyBlockArgument &self) { 2438436c6c9cSStella Laurenzo return PyBlock(self.getParentOperation(), 2439436c6c9cSStella Laurenzo mlirBlockArgumentGetOwner(self.get())); 2440436c6c9cSStella Laurenzo }); 2441b56d1ec6SPeter Hawkins c.def_prop_ro("arg_number", [](PyBlockArgument &self) { 2442436c6c9cSStella Laurenzo return mlirBlockArgumentGetArgNumber(self.get()); 2443436c6c9cSStella Laurenzo }); 2444a6e7d024SStella Laurenzo c.def( 2445a6e7d024SStella Laurenzo "set_type", 2446a6e7d024SStella Laurenzo [](PyBlockArgument &self, PyType type) { 2447436c6c9cSStella Laurenzo return mlirBlockArgumentSetType(self.get(), type); 2448a6e7d024SStella Laurenzo }, 2449b56d1ec6SPeter Hawkins nb::arg("type")); 2450436c6c9cSStella Laurenzo } 2451436c6c9cSStella Laurenzo }; 2452436c6c9cSStella Laurenzo 2453436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive 2454436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the 2455436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in 2456436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime. 2457afeda4b9SAlex Zinenko class PyBlockArgumentList 2458afeda4b9SAlex Zinenko : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 2459436c6c9cSStella Laurenzo public: 2460afeda4b9SAlex Zinenko static constexpr const char *pyClassName = "BlockArgumentList"; 24617c850867SMaksim Levental using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>; 2462436c6c9cSStella Laurenzo 2463afeda4b9SAlex Zinenko PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 2464afeda4b9SAlex Zinenko intptr_t startIndex = 0, intptr_t length = -1, 2465afeda4b9SAlex Zinenko intptr_t step = 1) 2466afeda4b9SAlex Zinenko : Sliceable(startIndex, 2467afeda4b9SAlex Zinenko length == -1 ? mlirBlockGetNumArguments(block) : length, 2468afeda4b9SAlex Zinenko step), 2469afeda4b9SAlex Zinenko operation(std::move(operation)), block(block) {} 2470afeda4b9SAlex Zinenko 2471ee168fb9SAlex Zinenko static void bindDerived(ClassTy &c) { 2472b56d1ec6SPeter Hawkins c.def_prop_ro("types", [](PyBlockArgumentList &self) { 2473ee168fb9SAlex Zinenko return getValueTypes(self, self.operation->getContext()); 2474ee168fb9SAlex Zinenko }); 2475ee168fb9SAlex Zinenko } 2476ee168fb9SAlex Zinenko 2477ee168fb9SAlex Zinenko private: 2478ee168fb9SAlex Zinenko /// Give the parent CRTP class access to hook implementations below. 2479ee168fb9SAlex Zinenko friend class Sliceable<PyBlockArgumentList, PyBlockArgument>; 2480ee168fb9SAlex Zinenko 2481afeda4b9SAlex Zinenko /// Returns the number of arguments in the list. 2482ee168fb9SAlex Zinenko intptr_t getRawNumElements() { 2483436c6c9cSStella Laurenzo operation->checkValid(); 2484436c6c9cSStella Laurenzo return mlirBlockGetNumArguments(block); 2485436c6c9cSStella Laurenzo } 2486436c6c9cSStella Laurenzo 2487ee168fb9SAlex Zinenko /// Returns `pos`-the element in the list. 2488ee168fb9SAlex Zinenko PyBlockArgument getRawElement(intptr_t pos) { 2489afeda4b9SAlex Zinenko MlirValue argument = mlirBlockGetArgument(block, pos); 2490afeda4b9SAlex Zinenko return PyBlockArgument(operation, argument); 2491436c6c9cSStella Laurenzo } 2492436c6c9cSStella Laurenzo 2493afeda4b9SAlex Zinenko /// Returns a sublist of this list. 2494afeda4b9SAlex Zinenko PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 2495afeda4b9SAlex Zinenko intptr_t step) { 2496afeda4b9SAlex Zinenko return PyBlockArgumentList(operation, block, startIndex, length, step); 2497436c6c9cSStella Laurenzo } 2498436c6c9cSStella Laurenzo 2499436c6c9cSStella Laurenzo PyOperationRef operation; 2500436c6c9cSStella Laurenzo MlirBlock block; 2501436c6c9cSStella Laurenzo }; 2502436c6c9cSStella Laurenzo 2503436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive 2504d7e49736SMaksim Levental /// elements, random access is cheap. The (returned) operand list is associated 2505d7e49736SMaksim Levental /// with the operation whose operands these are, and thus extends the lifetime 2506d7e49736SMaksim Levental /// of this operation. 2507436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2508436c6c9cSStella Laurenzo public: 2509436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpOperandList"; 25107c850867SMaksim Levental using SliceableT = Sliceable<PyOpOperandList, PyValue>; 2511436c6c9cSStella Laurenzo 2512436c6c9cSStella Laurenzo PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2513436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 2514436c6c9cSStella Laurenzo : Sliceable(startIndex, 2515436c6c9cSStella Laurenzo length == -1 ? mlirOperationGetNumOperands(operation->get()) 2516436c6c9cSStella Laurenzo : length, 2517436c6c9cSStella Laurenzo step), 2518436c6c9cSStella Laurenzo operation(operation) {} 2519436c6c9cSStella Laurenzo 2520ee168fb9SAlex Zinenko void dunderSetItem(intptr_t index, PyValue value) { 2521ee168fb9SAlex Zinenko index = wrapIndex(index); 2522ee168fb9SAlex Zinenko mlirOperationSetOperand(operation->get(), index, value.get()); 2523ee168fb9SAlex Zinenko } 2524ee168fb9SAlex Zinenko 2525ee168fb9SAlex Zinenko static void bindDerived(ClassTy &c) { 2526ee168fb9SAlex Zinenko c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2527ee168fb9SAlex Zinenko } 2528ee168fb9SAlex Zinenko 2529ee168fb9SAlex Zinenko private: 2530ee168fb9SAlex Zinenko /// Give the parent CRTP class access to hook implementations below. 2531ee168fb9SAlex Zinenko friend class Sliceable<PyOpOperandList, PyValue>; 2532ee168fb9SAlex Zinenko 2533ee168fb9SAlex Zinenko intptr_t getRawNumElements() { 2534436c6c9cSStella Laurenzo operation->checkValid(); 2535436c6c9cSStella Laurenzo return mlirOperationGetNumOperands(operation->get()); 2536436c6c9cSStella Laurenzo } 2537436c6c9cSStella Laurenzo 2538ee168fb9SAlex Zinenko PyValue getRawElement(intptr_t pos) { 25395664c5e2SJohn Demme MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 25405664c5e2SJohn Demme MlirOperation owner; 25415664c5e2SJohn Demme if (mlirValueIsAOpResult(operand)) 25425664c5e2SJohn Demme owner = mlirOpResultGetOwner(operand); 25435664c5e2SJohn Demme else if (mlirValueIsABlockArgument(operand)) 25445664c5e2SJohn Demme owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 25455664c5e2SJohn Demme else 25465664c5e2SJohn Demme assert(false && "Value must be an block arg or op result."); 25475664c5e2SJohn Demme PyOperationRef pyOwner = 25485664c5e2SJohn Demme PyOperation::forOperation(operation->getContext(), owner); 25495664c5e2SJohn Demme return PyValue(pyOwner, operand); 2550436c6c9cSStella Laurenzo } 2551436c6c9cSStella Laurenzo 2552436c6c9cSStella Laurenzo PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2553436c6c9cSStella Laurenzo return PyOpOperandList(operation, startIndex, length, step); 2554436c6c9cSStella Laurenzo } 2555436c6c9cSStella Laurenzo 2556436c6c9cSStella Laurenzo PyOperationRef operation; 2557436c6c9cSStella Laurenzo }; 2558436c6c9cSStella Laurenzo 2559d7e49736SMaksim Levental /// A list of operation successors. Internally, these are stored as consecutive 2560d7e49736SMaksim Levental /// elements, random access is cheap. The (returned) successor list is 2561d7e49736SMaksim Levental /// associated with the operation whose successors these are, and thus extends 2562d7e49736SMaksim Levental /// the lifetime of this operation. 2563d7e49736SMaksim Levental class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> { 2564d7e49736SMaksim Levental public: 2565d7e49736SMaksim Levental static constexpr const char *pyClassName = "OpSuccessors"; 2566d7e49736SMaksim Levental 2567d7e49736SMaksim Levental PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, 2568d7e49736SMaksim Levental intptr_t length = -1, intptr_t step = 1) 2569d7e49736SMaksim Levental : Sliceable(startIndex, 2570d7e49736SMaksim Levental length == -1 ? mlirOperationGetNumSuccessors(operation->get()) 2571d7e49736SMaksim Levental : length, 2572d7e49736SMaksim Levental step), 2573d7e49736SMaksim Levental operation(operation) {} 2574d7e49736SMaksim Levental 2575d7e49736SMaksim Levental void dunderSetItem(intptr_t index, PyBlock block) { 2576d7e49736SMaksim Levental index = wrapIndex(index); 2577d7e49736SMaksim Levental mlirOperationSetSuccessor(operation->get(), index, block.get()); 2578d7e49736SMaksim Levental } 2579d7e49736SMaksim Levental 2580d7e49736SMaksim Levental static void bindDerived(ClassTy &c) { 2581d7e49736SMaksim Levental c.def("__setitem__", &PyOpSuccessors::dunderSetItem); 2582d7e49736SMaksim Levental } 2583d7e49736SMaksim Levental 2584d7e49736SMaksim Levental private: 2585d7e49736SMaksim Levental /// Give the parent CRTP class access to hook implementations below. 2586d7e49736SMaksim Levental friend class Sliceable<PyOpSuccessors, PyBlock>; 2587d7e49736SMaksim Levental 2588d7e49736SMaksim Levental intptr_t getRawNumElements() { 2589d7e49736SMaksim Levental operation->checkValid(); 2590d7e49736SMaksim Levental return mlirOperationGetNumSuccessors(operation->get()); 2591d7e49736SMaksim Levental } 2592d7e49736SMaksim Levental 2593d7e49736SMaksim Levental PyBlock getRawElement(intptr_t pos) { 2594d7e49736SMaksim Levental MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); 2595d7e49736SMaksim Levental return PyBlock(operation, block); 2596d7e49736SMaksim Levental } 2597d7e49736SMaksim Levental 2598d7e49736SMaksim Levental PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2599d7e49736SMaksim Levental return PyOpSuccessors(operation, startIndex, length, step); 2600d7e49736SMaksim Levental } 2601d7e49736SMaksim Levental 2602d7e49736SMaksim Levental PyOperationRef operation; 2603d7e49736SMaksim Levental }; 2604d7e49736SMaksim Levental 2605436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing 2606436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes. 2607436c6c9cSStella Laurenzo class PyOpAttributeMap { 2608436c6c9cSStella Laurenzo public: 26091fc096afSMehdi Amini PyOpAttributeMap(PyOperationRef operation) 26101fc096afSMehdi Amini : operation(std::move(operation)) {} 2611436c6c9cSStella Laurenzo 2612974c1596SRahul Kayaith MlirAttribute dunderGetItemNamed(const std::string &name) { 2613436c6c9cSStella Laurenzo MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2614436c6c9cSStella Laurenzo toMlirStringRef(name)); 2615436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 2616b56d1ec6SPeter Hawkins throw nb::key_error("attempt to access a non-existent attribute"); 2617436c6c9cSStella Laurenzo } 2618974c1596SRahul Kayaith return attr; 2619436c6c9cSStella Laurenzo } 2620436c6c9cSStella Laurenzo 2621436c6c9cSStella Laurenzo PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2622436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 2623b56d1ec6SPeter Hawkins throw nb::index_error("attempt to access out of bounds attribute"); 2624436c6c9cSStella Laurenzo } 2625436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = 2626436c6c9cSStella Laurenzo mlirOperationGetAttribute(operation->get(), index); 2627436c6c9cSStella Laurenzo return PyNamedAttribute( 2628436c6c9cSStella Laurenzo namedAttr.attribute, 2629120591e1SRiver Riddle std::string(mlirIdentifierStr(namedAttr.name).data, 2630120591e1SRiver Riddle mlirIdentifierStr(namedAttr.name).length)); 2631436c6c9cSStella Laurenzo } 2632436c6c9cSStella Laurenzo 26331fc096afSMehdi Amini void dunderSetItem(const std::string &name, const PyAttribute &attr) { 2634436c6c9cSStella Laurenzo mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2635436c6c9cSStella Laurenzo attr); 2636436c6c9cSStella Laurenzo } 2637436c6c9cSStella Laurenzo 2638436c6c9cSStella Laurenzo void dunderDelItem(const std::string &name) { 2639436c6c9cSStella Laurenzo int removed = mlirOperationRemoveAttributeByName(operation->get(), 2640436c6c9cSStella Laurenzo toMlirStringRef(name)); 2641436c6c9cSStella Laurenzo if (!removed) 2642b56d1ec6SPeter Hawkins throw nb::key_error("attempt to delete a non-existent attribute"); 2643436c6c9cSStella Laurenzo } 2644436c6c9cSStella Laurenzo 2645436c6c9cSStella Laurenzo intptr_t dunderLen() { 2646436c6c9cSStella Laurenzo return mlirOperationGetNumAttributes(operation->get()); 2647436c6c9cSStella Laurenzo } 2648436c6c9cSStella Laurenzo 2649436c6c9cSStella Laurenzo bool dunderContains(const std::string &name) { 2650436c6c9cSStella Laurenzo return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2651436c6c9cSStella Laurenzo operation->get(), toMlirStringRef(name))); 2652436c6c9cSStella Laurenzo } 2653436c6c9cSStella Laurenzo 2654b56d1ec6SPeter Hawkins static void bind(nb::module_ &m) { 2655b56d1ec6SPeter Hawkins nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") 2656436c6c9cSStella Laurenzo .def("__contains__", &PyOpAttributeMap::dunderContains) 2657436c6c9cSStella Laurenzo .def("__len__", &PyOpAttributeMap::dunderLen) 2658436c6c9cSStella Laurenzo .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2659436c6c9cSStella Laurenzo .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2660436c6c9cSStella Laurenzo .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2661436c6c9cSStella Laurenzo .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2662436c6c9cSStella Laurenzo } 2663436c6c9cSStella Laurenzo 2664436c6c9cSStella Laurenzo private: 2665436c6c9cSStella Laurenzo PyOperationRef operation; 2666436c6c9cSStella Laurenzo }; 2667436c6c9cSStella Laurenzo 2668be0a7e9fSMehdi Amini } // namespace 2669436c6c9cSStella Laurenzo 2670436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2671436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule. 2672436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 2673436c6c9cSStella Laurenzo 2674b56d1ec6SPeter Hawkins void mlir::python::populateIRCore(nb::module_ &m) { 267508e2c15aSMaksim Levental // disable leak warnings which tend to be false positives. 267608e2c15aSMaksim Levental nb::set_leak_warnings(false); 2677436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 26787ee25bc5SStella Laurenzo // Enums. 26797ee25bc5SStella Laurenzo //---------------------------------------------------------------------------- 2680b56d1ec6SPeter Hawkins nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity") 26817ee25bc5SStella Laurenzo .value("ERROR", MlirDiagnosticError) 26827ee25bc5SStella Laurenzo .value("WARNING", MlirDiagnosticWarning) 26837ee25bc5SStella Laurenzo .value("NOTE", MlirDiagnosticNote) 26847ee25bc5SStella Laurenzo .value("REMARK", MlirDiagnosticRemark); 26857ee25bc5SStella Laurenzo 2686b56d1ec6SPeter Hawkins nb::enum_<MlirWalkOrder>(m, "WalkOrder") 268747148832SHideto Ueno .value("PRE_ORDER", MlirWalkPreOrder) 268847148832SHideto Ueno .value("POST_ORDER", MlirWalkPostOrder); 268947148832SHideto Ueno 2690b56d1ec6SPeter Hawkins nb::enum_<MlirWalkResult>(m, "WalkResult") 269147148832SHideto Ueno .value("ADVANCE", MlirWalkResultAdvance) 269247148832SHideto Ueno .value("INTERRUPT", MlirWalkResultInterrupt) 269347148832SHideto Ueno .value("SKIP", MlirWalkResultSkip); 269447148832SHideto Ueno 26957ee25bc5SStella Laurenzo //---------------------------------------------------------------------------- 26967ee25bc5SStella Laurenzo // Mapping of Diagnostics. 26977ee25bc5SStella Laurenzo //---------------------------------------------------------------------------- 2698b56d1ec6SPeter Hawkins nb::class_<PyDiagnostic>(m, "Diagnostic") 2699b56d1ec6SPeter Hawkins .def_prop_ro("severity", &PyDiagnostic::getSeverity) 2700b56d1ec6SPeter Hawkins .def_prop_ro("location", &PyDiagnostic::getLocation) 2701b56d1ec6SPeter Hawkins .def_prop_ro("message", &PyDiagnostic::getMessage) 2702b56d1ec6SPeter Hawkins .def_prop_ro("notes", &PyDiagnostic::getNotes) 2703b56d1ec6SPeter Hawkins .def("__str__", [](PyDiagnostic &self) -> nb::str { 27047ee25bc5SStella Laurenzo if (!self.isValid()) 2705b56d1ec6SPeter Hawkins return nb::str("<Invalid Diagnostic>"); 27067ee25bc5SStella Laurenzo return self.getMessage(); 27077ee25bc5SStella Laurenzo }); 27087ee25bc5SStella Laurenzo 2709b56d1ec6SPeter Hawkins nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo") 2710b56d1ec6SPeter Hawkins .def("__init__", 2711b56d1ec6SPeter Hawkins [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { 2712b56d1ec6SPeter Hawkins new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); 2713b56d1ec6SPeter Hawkins }) 2714b56d1ec6SPeter Hawkins .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) 2715b56d1ec6SPeter Hawkins .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) 2716b56d1ec6SPeter Hawkins .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) 2717b56d1ec6SPeter Hawkins .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) 27183ea4c501SRahul Kayaith .def("__str__", 27193ea4c501SRahul Kayaith [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); 27203ea4c501SRahul Kayaith 2721b56d1ec6SPeter Hawkins nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler") 27227ee25bc5SStella Laurenzo .def("detach", &PyDiagnosticHandler::detach) 2723b56d1ec6SPeter Hawkins .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) 2724b56d1ec6SPeter Hawkins .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) 27257ee25bc5SStella Laurenzo .def("__enter__", &PyDiagnosticHandler::contextEnter) 2726b56d1ec6SPeter Hawkins .def("__exit__", &PyDiagnosticHandler::contextExit, 2727b56d1ec6SPeter Hawkins nb::arg("exc_type").none(), nb::arg("exc_value").none(), 2728b56d1ec6SPeter Hawkins nb::arg("traceback").none()); 27297ee25bc5SStella Laurenzo 27307ee25bc5SStella Laurenzo //---------------------------------------------------------------------------- 27314acd8457SAlex Zinenko // Mapping of MlirContext. 27325e83a5b4SStella Laurenzo // Note that this is exported as _BaseContext. The containing, Python level 27335e83a5b4SStella Laurenzo // __init__.py will subclass it with site-specific functionality and set a 27345e83a5b4SStella Laurenzo // "Context" attribute on this module. 2735436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2736b56d1ec6SPeter Hawkins nb::class_<PyMlirContext>(m, "_BaseContext") 2737b56d1ec6SPeter Hawkins .def("__init__", 2738b56d1ec6SPeter Hawkins [](PyMlirContext &self) { 2739b56d1ec6SPeter Hawkins MlirContext context = mlirContextCreateWithThreading(false); 2740b56d1ec6SPeter Hawkins new (&self) PyMlirContext(context); 2741b56d1ec6SPeter Hawkins }) 2742436c6c9cSStella Laurenzo .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2743436c6c9cSStella Laurenzo .def("_get_context_again", 2744436c6c9cSStella Laurenzo [](PyMlirContext &self) { 2745436c6c9cSStella Laurenzo PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2746436c6c9cSStella Laurenzo return ref.releaseObject(); 2747436c6c9cSStella Laurenzo }) 2748436c6c9cSStella Laurenzo .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2749d1fdb416SJohn Demme .def("_get_live_operation_objects", 2750d1fdb416SJohn Demme &PyMlirContext::getLiveOperationObjects) 27516b0bed7eSJohn Demme .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 275291f11611SOleksandr "Alex" Zinenko .def("_clear_live_operations_inside", 2753b56d1ec6SPeter Hawkins nb::overload_cast<MlirOperation>( 275491f11611SOleksandr "Alex" Zinenko &PyMlirContext::clearOperationsInside)) 2755436c6c9cSStella Laurenzo .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2756b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) 2757436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2758436c6c9cSStella Laurenzo .def("__enter__", &PyMlirContext::contextEnter) 2759b56d1ec6SPeter Hawkins .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), 2760b56d1ec6SPeter Hawkins nb::arg("exc_value").none(), nb::arg("traceback").none()) 2761b56d1ec6SPeter Hawkins .def_prop_ro_static( 2762436c6c9cSStella Laurenzo "current", 2763b56d1ec6SPeter Hawkins [](nb::object & /*class*/) { 2764436c6c9cSStella Laurenzo auto *context = PyThreadContextEntry::getDefaultContext(); 2765436c6c9cSStella Laurenzo if (!context) 2766b56d1ec6SPeter Hawkins return nb::none(); 2767b56d1ec6SPeter Hawkins return nb::cast(context); 2768436c6c9cSStella Laurenzo }, 2769436c6c9cSStella Laurenzo "Gets the Context bound to the current thread or raises ValueError") 2770b56d1ec6SPeter Hawkins .def_prop_ro( 2771436c6c9cSStella Laurenzo "dialects", 2772436c6c9cSStella Laurenzo [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2773436c6c9cSStella Laurenzo "Gets a container for accessing dialects by name") 2774b56d1ec6SPeter Hawkins .def_prop_ro( 2775436c6c9cSStella Laurenzo "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2776436c6c9cSStella Laurenzo "Alias for 'dialect'") 2777436c6c9cSStella Laurenzo .def( 2778436c6c9cSStella Laurenzo "get_dialect_descriptor", 2779436c6c9cSStella Laurenzo [=](PyMlirContext &self, std::string &name) { 2780436c6c9cSStella Laurenzo MlirDialect dialect = mlirContextGetOrLoadDialect( 2781436c6c9cSStella Laurenzo self.get(), {name.data(), name.size()}); 2782436c6c9cSStella Laurenzo if (mlirDialectIsNull(dialect)) { 2783b56d1ec6SPeter Hawkins throw nb::value_error( 2784b56d1ec6SPeter Hawkins (Twine("Dialect '") + name + "' not found").str().c_str()); 2785436c6c9cSStella Laurenzo } 2786436c6c9cSStella Laurenzo return PyDialectDescriptor(self.getRef(), dialect); 2787436c6c9cSStella Laurenzo }, 2788b56d1ec6SPeter Hawkins nb::arg("dialect_name"), 2789436c6c9cSStella Laurenzo "Gets or loads a dialect by name, returning its descriptor object") 2790b56d1ec6SPeter Hawkins .def_prop_rw( 2791436c6c9cSStella Laurenzo "allow_unregistered_dialects", 2792436c6c9cSStella Laurenzo [](PyMlirContext &self) -> bool { 2793436c6c9cSStella Laurenzo return mlirContextGetAllowUnregisteredDialects(self.get()); 2794436c6c9cSStella Laurenzo }, 2795436c6c9cSStella Laurenzo [](PyMlirContext &self, bool value) { 2796436c6c9cSStella Laurenzo mlirContextSetAllowUnregisteredDialects(self.get(), value); 27979a9214faSStella Laurenzo }) 27987ee25bc5SStella Laurenzo .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2799b56d1ec6SPeter Hawkins nb::arg("callback"), 28007ee25bc5SStella Laurenzo "Attaches a diagnostic handler that will receive callbacks") 2801a6e7d024SStella Laurenzo .def( 2802a6e7d024SStella Laurenzo "enable_multithreading", 2803caa159f0SNicolas Vasilache [](PyMlirContext &self, bool enable) { 2804caa159f0SNicolas Vasilache mlirContextEnableMultithreading(self.get(), enable); 2805a6e7d024SStella Laurenzo }, 2806b56d1ec6SPeter Hawkins nb::arg("enable")) 2807a6e7d024SStella Laurenzo .def( 2808a6e7d024SStella Laurenzo "is_registered_operation", 28099a9214faSStella Laurenzo [](PyMlirContext &self, std::string &name) { 28109a9214faSStella Laurenzo return mlirContextIsRegisteredOperation( 28119a9214faSStella Laurenzo self.get(), MlirStringRef{name.data(), name.size()}); 2812a6e7d024SStella Laurenzo }, 2813b56d1ec6SPeter Hawkins nb::arg("operation_name")) 28145e83a5b4SStella Laurenzo .def( 28155e83a5b4SStella Laurenzo "append_dialect_registry", 28165e83a5b4SStella Laurenzo [](PyMlirContext &self, PyDialectRegistry ®istry) { 28175e83a5b4SStella Laurenzo mlirContextAppendDialectRegistry(self.get(), registry); 28185e83a5b4SStella Laurenzo }, 2819b56d1ec6SPeter Hawkins nb::arg("registry")) 2820b56d1ec6SPeter Hawkins .def_prop_rw("emit_error_diagnostics", nullptr, 28213ea4c501SRahul Kayaith &PyMlirContext::setEmitErrorDiagnostics, 28223ea4c501SRahul Kayaith "Emit error diagnostics to diagnostic handlers. By default " 28233ea4c501SRahul Kayaith "error diagnostics are captured and reported through " 28243ea4c501SRahul Kayaith "MLIRError exceptions.") 28255e83a5b4SStella Laurenzo .def("load_all_available_dialects", [](PyMlirContext &self) { 28265e83a5b4SStella Laurenzo mlirContextLoadAllAvailableDialects(self.get()); 28275e83a5b4SStella Laurenzo }); 2828436c6c9cSStella Laurenzo 2829436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2830436c6c9cSStella Laurenzo // Mapping of PyDialectDescriptor 2831436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2832b56d1ec6SPeter Hawkins nb::class_<PyDialectDescriptor>(m, "DialectDescriptor") 2833b56d1ec6SPeter Hawkins .def_prop_ro("namespace", 2834436c6c9cSStella Laurenzo [](PyDialectDescriptor &self) { 2835b56d1ec6SPeter Hawkins MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2836b56d1ec6SPeter Hawkins return nb::str(ns.data, ns.length); 2837436c6c9cSStella Laurenzo }) 2838436c6c9cSStella Laurenzo .def("__repr__", [](PyDialectDescriptor &self) { 2839436c6c9cSStella Laurenzo MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2840436c6c9cSStella Laurenzo std::string repr("<DialectDescriptor "); 2841436c6c9cSStella Laurenzo repr.append(ns.data, ns.length); 2842436c6c9cSStella Laurenzo repr.append(">"); 2843436c6c9cSStella Laurenzo return repr; 2844436c6c9cSStella Laurenzo }); 2845436c6c9cSStella Laurenzo 2846436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2847436c6c9cSStella Laurenzo // Mapping of PyDialects 2848436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2849b56d1ec6SPeter Hawkins nb::class_<PyDialects>(m, "Dialects") 2850436c6c9cSStella Laurenzo .def("__getitem__", 2851436c6c9cSStella Laurenzo [=](PyDialects &self, std::string keyName) { 2852436c6c9cSStella Laurenzo MlirDialect dialect = 2853436c6c9cSStella Laurenzo self.getDialectForKey(keyName, /*attrError=*/false); 2854b56d1ec6SPeter Hawkins nb::object descriptor = 2855b56d1ec6SPeter Hawkins nb::cast(PyDialectDescriptor{self.getContext(), dialect}); 2856436c6c9cSStella Laurenzo return createCustomDialectWrapper(keyName, std::move(descriptor)); 2857436c6c9cSStella Laurenzo }) 2858436c6c9cSStella Laurenzo .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2859436c6c9cSStella Laurenzo MlirDialect dialect = 2860436c6c9cSStella Laurenzo self.getDialectForKey(attrName, /*attrError=*/true); 2861b56d1ec6SPeter Hawkins nb::object descriptor = 2862b56d1ec6SPeter Hawkins nb::cast(PyDialectDescriptor{self.getContext(), dialect}); 2863436c6c9cSStella Laurenzo return createCustomDialectWrapper(attrName, std::move(descriptor)); 2864436c6c9cSStella Laurenzo }); 2865436c6c9cSStella Laurenzo 2866436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2867436c6c9cSStella Laurenzo // Mapping of PyDialect 2868436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2869b56d1ec6SPeter Hawkins nb::class_<PyDialect>(m, "Dialect") 2870b56d1ec6SPeter Hawkins .def(nb::init<nb::object>(), nb::arg("descriptor")) 2871b56d1ec6SPeter Hawkins .def_prop_ro("descriptor", 2872b56d1ec6SPeter Hawkins [](PyDialect &self) { return self.getDescriptor(); }) 2873b56d1ec6SPeter Hawkins .def("__repr__", [](nb::object self) { 2874436c6c9cSStella Laurenzo auto clazz = self.attr("__class__"); 2875b56d1ec6SPeter Hawkins return nb::str("<Dialect ") + 2876b56d1ec6SPeter Hawkins self.attr("descriptor").attr("namespace") + nb::str(" (class ") + 2877b56d1ec6SPeter Hawkins clazz.attr("__module__") + nb::str(".") + 2878b56d1ec6SPeter Hawkins clazz.attr("__name__") + nb::str(")>"); 2879436c6c9cSStella Laurenzo }); 2880436c6c9cSStella Laurenzo 2881436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 28825e83a5b4SStella Laurenzo // Mapping of PyDialectRegistry 28835e83a5b4SStella Laurenzo //---------------------------------------------------------------------------- 2884b56d1ec6SPeter Hawkins nb::class_<PyDialectRegistry>(m, "DialectRegistry") 2885b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) 28865e83a5b4SStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) 2887b56d1ec6SPeter Hawkins .def(nb::init<>()); 28885e83a5b4SStella Laurenzo 28895e83a5b4SStella Laurenzo //---------------------------------------------------------------------------- 2890436c6c9cSStella Laurenzo // Mapping of Location 2891436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2892b56d1ec6SPeter Hawkins nb::class_<PyLocation>(m, "Location") 2893b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2894436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2895436c6c9cSStella Laurenzo .def("__enter__", &PyLocation::contextEnter) 2896b56d1ec6SPeter Hawkins .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), 2897b56d1ec6SPeter Hawkins nb::arg("exc_value").none(), nb::arg("traceback").none()) 2898436c6c9cSStella Laurenzo .def("__eq__", 2899436c6c9cSStella Laurenzo [](PyLocation &self, PyLocation &other) -> bool { 2900436c6c9cSStella Laurenzo return mlirLocationEqual(self, other); 2901436c6c9cSStella Laurenzo }) 2902b56d1ec6SPeter Hawkins .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) 2903b56d1ec6SPeter Hawkins .def_prop_ro_static( 2904436c6c9cSStella Laurenzo "current", 2905b56d1ec6SPeter Hawkins [](nb::object & /*class*/) { 2906436c6c9cSStella Laurenzo auto *loc = PyThreadContextEntry::getDefaultLocation(); 2907436c6c9cSStella Laurenzo if (!loc) 2908b56d1ec6SPeter Hawkins throw nb::value_error("No current Location"); 2909436c6c9cSStella Laurenzo return loc; 2910436c6c9cSStella Laurenzo }, 2911436c6c9cSStella Laurenzo "Gets the Location bound to the current thread or raises ValueError") 2912436c6c9cSStella Laurenzo .def_static( 2913436c6c9cSStella Laurenzo "unknown", 2914436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 2915436c6c9cSStella Laurenzo return PyLocation(context->getRef(), 2916436c6c9cSStella Laurenzo mlirLocationUnknownGet(context->get())); 2917436c6c9cSStella Laurenzo }, 2918b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 2919436c6c9cSStella Laurenzo "Gets a Location representing an unknown location") 2920436c6c9cSStella Laurenzo .def_static( 2921e67cbbefSJacques Pienaar "callsite", 2922e67cbbefSJacques Pienaar [](PyLocation callee, const std::vector<PyLocation> &frames, 2923e67cbbefSJacques Pienaar DefaultingPyMlirContext context) { 2924e67cbbefSJacques Pienaar if (frames.empty()) 2925b56d1ec6SPeter Hawkins throw nb::value_error("No caller frames provided"); 2926e67cbbefSJacques Pienaar MlirLocation caller = frames.back().get(); 2927e2f16be5SMehdi Amini for (const PyLocation &frame : 2928984b800aSserge-sans-paille llvm::reverse(llvm::ArrayRef(frames).drop_back())) 2929e67cbbefSJacques Pienaar caller = mlirLocationCallSiteGet(frame.get(), caller); 2930e67cbbefSJacques Pienaar return PyLocation(context->getRef(), 2931e67cbbefSJacques Pienaar mlirLocationCallSiteGet(callee.get(), caller)); 2932e67cbbefSJacques Pienaar }, 2933b56d1ec6SPeter Hawkins nb::arg("callee"), nb::arg("frames"), 2934b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 2935e67cbbefSJacques Pienaar kContextGetCallSiteLocationDocstring) 2936e67cbbefSJacques Pienaar .def_static( 2937436c6c9cSStella Laurenzo "file", 2938436c6c9cSStella Laurenzo [](std::string filename, int line, int col, 2939436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 2940436c6c9cSStella Laurenzo return PyLocation( 2941436c6c9cSStella Laurenzo context->getRef(), 2942436c6c9cSStella Laurenzo mlirLocationFileLineColGet( 2943436c6c9cSStella Laurenzo context->get(), toMlirStringRef(filename), line, col)); 2944436c6c9cSStella Laurenzo }, 2945b56d1ec6SPeter Hawkins nb::arg("filename"), nb::arg("line"), nb::arg("col"), 2946b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 2947b56d1ec6SPeter Hawkins kContextGetFileLocationDocstring) 294804d76d36SJacques Pienaar .def_static( 2949a77250fdSJacques Pienaar "file", 2950a77250fdSJacques Pienaar [](std::string filename, int startLine, int startCol, int endLine, 2951a77250fdSJacques Pienaar int endCol, DefaultingPyMlirContext context) { 2952a77250fdSJacques Pienaar return PyLocation(context->getRef(), 2953a77250fdSJacques Pienaar mlirLocationFileLineColRangeGet( 2954a77250fdSJacques Pienaar context->get(), toMlirStringRef(filename), 2955a77250fdSJacques Pienaar startLine, startCol, endLine, endCol)); 2956a77250fdSJacques Pienaar }, 2957a77250fdSJacques Pienaar nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), 2958a77250fdSJacques Pienaar nb::arg("end_line"), nb::arg("end_col"), 2959a77250fdSJacques Pienaar nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) 2960a77250fdSJacques Pienaar .def_static( 29611ab3efacSJacques Pienaar "fused", 29627ee25bc5SStella Laurenzo [](const std::vector<PyLocation> &pyLocations, 29630a81ace0SKazu Hirata std::optional<PyAttribute> metadata, 29641ab3efacSJacques Pienaar DefaultingPyMlirContext context) { 29651ab3efacSJacques Pienaar llvm::SmallVector<MlirLocation, 4> locations; 29661ab3efacSJacques Pienaar locations.reserve(pyLocations.size()); 29671ab3efacSJacques Pienaar for (auto &pyLocation : pyLocations) 29681ab3efacSJacques Pienaar locations.push_back(pyLocation.get()); 29691ab3efacSJacques Pienaar MlirLocation location = mlirLocationFusedGet( 29701ab3efacSJacques Pienaar context->get(), locations.size(), locations.data(), 29711ab3efacSJacques Pienaar metadata ? metadata->get() : MlirAttribute{0}); 29721ab3efacSJacques Pienaar return PyLocation(context->getRef(), location); 29731ab3efacSJacques Pienaar }, 2974b56d1ec6SPeter Hawkins nb::arg("locations"), nb::arg("metadata").none() = nb::none(), 2975b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 2976b56d1ec6SPeter Hawkins kContextGetFusedLocationDocstring) 29771ab3efacSJacques Pienaar .def_static( 297804d76d36SJacques Pienaar "name", 29790a81ace0SKazu Hirata [](std::string name, std::optional<PyLocation> childLoc, 298004d76d36SJacques Pienaar DefaultingPyMlirContext context) { 298104d76d36SJacques Pienaar return PyLocation( 298204d76d36SJacques Pienaar context->getRef(), 298304d76d36SJacques Pienaar mlirLocationNameGet( 298404d76d36SJacques Pienaar context->get(), toMlirStringRef(name), 298504d76d36SJacques Pienaar childLoc ? childLoc->get() 298604d76d36SJacques Pienaar : mlirLocationUnknownGet(context->get()))); 298704d76d36SJacques Pienaar }, 2988b56d1ec6SPeter Hawkins nb::arg("name"), nb::arg("childLoc").none() = nb::none(), 2989b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 2990b56d1ec6SPeter Hawkins kContextGetNameLocationDocString) 2991792f3c81SAndrew Young .def_static( 2992792f3c81SAndrew Young "from_attr", 2993792f3c81SAndrew Young [](PyAttribute &attribute, DefaultingPyMlirContext context) { 2994792f3c81SAndrew Young return PyLocation(context->getRef(), 2995792f3c81SAndrew Young mlirLocationFromAttribute(attribute)); 2996792f3c81SAndrew Young }, 2997b56d1ec6SPeter Hawkins nb::arg("attribute"), nb::arg("context").none() = nb::none(), 2998792f3c81SAndrew Young "Gets a Location from a LocationAttr") 2999b56d1ec6SPeter Hawkins .def_prop_ro( 3000436c6c9cSStella Laurenzo "context", 3001436c6c9cSStella Laurenzo [](PyLocation &self) { return self.getContext().getObject(); }, 3002436c6c9cSStella Laurenzo "Context that owns the Location") 3003b56d1ec6SPeter Hawkins .def_prop_ro( 3004792f3c81SAndrew Young "attr", 30059566ee28Smax [](PyLocation &self) { return mlirLocationGetAttribute(self); }, 3006792f3c81SAndrew Young "Get the underlying LocationAttr") 30077ee25bc5SStella Laurenzo .def( 30087ee25bc5SStella Laurenzo "emit_error", 30097ee25bc5SStella Laurenzo [](PyLocation &self, std::string message) { 30107ee25bc5SStella Laurenzo mlirEmitError(self, message.c_str()); 30117ee25bc5SStella Laurenzo }, 3012b56d1ec6SPeter Hawkins nb::arg("message"), "Emits an error at this location") 3013436c6c9cSStella Laurenzo .def("__repr__", [](PyLocation &self) { 3014436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3015436c6c9cSStella Laurenzo mlirLocationPrint(self, printAccum.getCallback(), 3016436c6c9cSStella Laurenzo printAccum.getUserData()); 3017436c6c9cSStella Laurenzo return printAccum.join(); 3018436c6c9cSStella Laurenzo }); 3019436c6c9cSStella Laurenzo 3020436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3021436c6c9cSStella Laurenzo // Mapping of Module 3022436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3023b56d1ec6SPeter Hawkins nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable()) 3024b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 3025436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 3026436c6c9cSStella Laurenzo .def_static( 3027436c6c9cSStella Laurenzo "parse", 30283ea4c501SRahul Kayaith [](const std::string &moduleAsm, DefaultingPyMlirContext context) { 30293ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(context->getRef()); 3030436c6c9cSStella Laurenzo MlirModule module = mlirModuleCreateParse( 3031436c6c9cSStella Laurenzo context->get(), toMlirStringRef(moduleAsm)); 30323ea4c501SRahul Kayaith if (mlirModuleIsNull(module)) 30333ea4c501SRahul Kayaith throw MLIRError("Unable to parse module assembly", errors.take()); 3034436c6c9cSStella Laurenzo return PyModule::forModule(module).releaseObject(); 3035436c6c9cSStella Laurenzo }, 3036b56d1ec6SPeter Hawkins nb::arg("asm"), nb::arg("context").none() = nb::none(), 3037b56d1ec6SPeter Hawkins kModuleParseDocstring) 3038b56d1ec6SPeter Hawkins .def_static( 3039b56d1ec6SPeter Hawkins "parse", 3040b56d1ec6SPeter Hawkins [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { 3041b56d1ec6SPeter Hawkins PyMlirContext::ErrorCapture errors(context->getRef()); 3042b56d1ec6SPeter Hawkins MlirModule module = mlirModuleCreateParse( 3043b56d1ec6SPeter Hawkins context->get(), toMlirStringRef(moduleAsm)); 3044b56d1ec6SPeter Hawkins if (mlirModuleIsNull(module)) 3045b56d1ec6SPeter Hawkins throw MLIRError("Unable to parse module assembly", errors.take()); 3046b56d1ec6SPeter Hawkins return PyModule::forModule(module).releaseObject(); 3047b56d1ec6SPeter Hawkins }, 3048b56d1ec6SPeter Hawkins nb::arg("asm"), nb::arg("context").none() = nb::none(), 3049436c6c9cSStella Laurenzo kModuleParseDocstring) 3050436c6c9cSStella Laurenzo .def_static( 3051436c6c9cSStella Laurenzo "create", 3052436c6c9cSStella Laurenzo [](DefaultingPyLocation loc) { 3053436c6c9cSStella Laurenzo MlirModule module = mlirModuleCreateEmpty(loc); 3054436c6c9cSStella Laurenzo return PyModule::forModule(module).releaseObject(); 3055436c6c9cSStella Laurenzo }, 3056b56d1ec6SPeter Hawkins nb::arg("loc").none() = nb::none(), "Creates an empty module") 3057b56d1ec6SPeter Hawkins .def_prop_ro( 3058436c6c9cSStella Laurenzo "context", 3059436c6c9cSStella Laurenzo [](PyModule &self) { return self.getContext().getObject(); }, 3060436c6c9cSStella Laurenzo "Context that created the Module") 3061b56d1ec6SPeter Hawkins .def_prop_ro( 3062436c6c9cSStella Laurenzo "operation", 3063436c6c9cSStella Laurenzo [](PyModule &self) { 3064436c6c9cSStella Laurenzo return PyOperation::forOperation(self.getContext(), 3065436c6c9cSStella Laurenzo mlirModuleGetOperation(self.get()), 3066436c6c9cSStella Laurenzo self.getRef().releaseObject()) 3067436c6c9cSStella Laurenzo .releaseObject(); 3068436c6c9cSStella Laurenzo }, 3069436c6c9cSStella Laurenzo "Accesses the module as an operation") 3070b56d1ec6SPeter Hawkins .def_prop_ro( 3071436c6c9cSStella Laurenzo "body", 3072436c6c9cSStella Laurenzo [](PyModule &self) { 307302b6fb21SMehdi Amini PyOperationRef moduleOp = PyOperation::forOperation( 3074436c6c9cSStella Laurenzo self.getContext(), mlirModuleGetOperation(self.get()), 3075436c6c9cSStella Laurenzo self.getRef().releaseObject()); 307602b6fb21SMehdi Amini PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 3077436c6c9cSStella Laurenzo return returnBlock; 3078436c6c9cSStella Laurenzo }, 3079436c6c9cSStella Laurenzo "Return the block for this module") 3080436c6c9cSStella Laurenzo .def( 3081436c6c9cSStella Laurenzo "dump", 3082436c6c9cSStella Laurenzo [](PyModule &self) { 3083436c6c9cSStella Laurenzo mlirOperationDump(mlirModuleGetOperation(self.get())); 3084436c6c9cSStella Laurenzo }, 3085436c6c9cSStella Laurenzo kDumpDocstring) 3086436c6c9cSStella Laurenzo .def( 3087436c6c9cSStella Laurenzo "__str__", 3088b56d1ec6SPeter Hawkins [](nb::object self) { 3089ace1d0adSStella Laurenzo // Defer to the operation's __str__. 3090ace1d0adSStella Laurenzo return self.attr("operation").attr("__str__")(); 3091436c6c9cSStella Laurenzo }, 3092436c6c9cSStella Laurenzo kOperationStrDunderDocstring); 3093436c6c9cSStella Laurenzo 3094436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3095436c6c9cSStella Laurenzo // Mapping of Operation. 3096436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3097b56d1ec6SPeter Hawkins nb::class_<PyOperationBase>(m, "_OperationBase") 3098b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, 30991fb2e842SStella Laurenzo [](PyOperationBase &self) { 31001fb2e842SStella Laurenzo return self.getOperation().getCapsule(); 31011fb2e842SStella Laurenzo }) 3102436c6c9cSStella Laurenzo .def("__eq__", 3103436c6c9cSStella Laurenzo [](PyOperationBase &self, PyOperationBase &other) { 3104436c6c9cSStella Laurenzo return &self.getOperation() == &other.getOperation(); 3105436c6c9cSStella Laurenzo }) 3106436c6c9cSStella Laurenzo .def("__eq__", 3107b56d1ec6SPeter Hawkins [](PyOperationBase &self, nb::object other) { return false; }) 3108f78fe0b7Srkayaith .def("__hash__", 3109f78fe0b7Srkayaith [](PyOperationBase &self) { 3110f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 3111f78fe0b7Srkayaith }) 3112b56d1ec6SPeter Hawkins .def_prop_ro("attributes", 3113436c6c9cSStella Laurenzo [](PyOperationBase &self) { 3114b56d1ec6SPeter Hawkins return PyOpAttributeMap(self.getOperation().getRef()); 3115436c6c9cSStella Laurenzo }) 3116b56d1ec6SPeter Hawkins .def_prop_ro( 311733df617dSStella Laurenzo "context", 311833df617dSStella Laurenzo [](PyOperationBase &self) { 311933df617dSStella Laurenzo PyOperation &concreteOperation = self.getOperation(); 312033df617dSStella Laurenzo concreteOperation.checkValid(); 312133df617dSStella Laurenzo return concreteOperation.getContext().getObject(); 312233df617dSStella Laurenzo }, 312333df617dSStella Laurenzo "Context that owns the Operation") 3124b56d1ec6SPeter Hawkins .def_prop_ro("name", 312533df617dSStella Laurenzo [](PyOperationBase &self) { 312633df617dSStella Laurenzo auto &concreteOperation = self.getOperation(); 312733df617dSStella Laurenzo concreteOperation.checkValid(); 3128b56d1ec6SPeter Hawkins MlirOperation operation = concreteOperation.get(); 3129b56d1ec6SPeter Hawkins MlirStringRef name = 3130b56d1ec6SPeter Hawkins mlirIdentifierStr(mlirOperationGetName(operation)); 3131b56d1ec6SPeter Hawkins return nb::str(name.data, name.length); 313233df617dSStella Laurenzo }) 3133b56d1ec6SPeter Hawkins .def_prop_ro("operands", 3134436c6c9cSStella Laurenzo [](PyOperationBase &self) { 3135b56d1ec6SPeter Hawkins return PyOpOperandList(self.getOperation().getRef()); 3136436c6c9cSStella Laurenzo }) 3137b56d1ec6SPeter Hawkins .def_prop_ro("regions", 3138436c6c9cSStella Laurenzo [](PyOperationBase &self) { 3139b56d1ec6SPeter Hawkins return PyRegionList(self.getOperation().getRef()); 3140436c6c9cSStella Laurenzo }) 3141b56d1ec6SPeter Hawkins .def_prop_ro( 3142436c6c9cSStella Laurenzo "results", 3143436c6c9cSStella Laurenzo [](PyOperationBase &self) { 3144436c6c9cSStella Laurenzo return PyOpResultList(self.getOperation().getRef()); 3145436c6c9cSStella Laurenzo }, 3146436c6c9cSStella Laurenzo "Returns the list of Operation results.") 3147b56d1ec6SPeter Hawkins .def_prop_ro( 3148436c6c9cSStella Laurenzo "result", 3149436c6c9cSStella Laurenzo [](PyOperationBase &self) { 3150436c6c9cSStella Laurenzo auto &operation = self.getOperation(); 3151*acde3f72SPeter Hawkins return PyOpResult(operation.getRef(), getUniqueResult(operation)) 31527c850867SMaksim Levental .maybeDownCast(); 3153436c6c9cSStella Laurenzo }, 3154436c6c9cSStella Laurenzo "Shortcut to get an op result if it has only one (throws an error " 3155436c6c9cSStella Laurenzo "otherwise).") 3156b56d1ec6SPeter Hawkins .def_prop_ro( 3157d5429a13Srkayaith "location", 3158d5429a13Srkayaith [](PyOperationBase &self) { 3159d5429a13Srkayaith PyOperation &operation = self.getOperation(); 3160d5429a13Srkayaith return PyLocation(operation.getContext(), 3161d5429a13Srkayaith mlirOperationGetLocation(operation.get())); 3162d5429a13Srkayaith }, 3163d5429a13Srkayaith "Returns the source location the operation was defined or derived " 3164d5429a13Srkayaith "from.") 3165b56d1ec6SPeter Hawkins .def_prop_ro("parent", 3166b56d1ec6SPeter Hawkins [](PyOperationBase &self) -> nb::object { 3167b56d1ec6SPeter Hawkins auto parent = self.getOperation().getParentOperation(); 316833df617dSStella Laurenzo if (parent) 316933df617dSStella Laurenzo return parent->getObject(); 3170b56d1ec6SPeter Hawkins return nb::none(); 317133df617dSStella Laurenzo }) 3172436c6c9cSStella Laurenzo .def( 3173436c6c9cSStella Laurenzo "__str__", 3174436c6c9cSStella Laurenzo [](PyOperationBase &self) { 3175436c6c9cSStella Laurenzo return self.getAsm(/*binary=*/false, 3176e823ababSKazu Hirata /*largeElementsLimit=*/std::nullopt, 3177436c6c9cSStella Laurenzo /*enableDebugInfo=*/false, 3178436c6c9cSStella Laurenzo /*prettyDebugInfo=*/false, 3179436c6c9cSStella Laurenzo /*printGenericOpForm=*/false, 3180ace1d0adSStella Laurenzo /*useLocalScope=*/false, 3181abad8455SJonas Rickert /*assumeVerified=*/false, 3182abad8455SJonas Rickert /*skipRegions=*/false); 3183436c6c9cSStella Laurenzo }, 3184436c6c9cSStella Laurenzo "Returns the assembly form of the operation.") 3185204acc5cSJacques Pienaar .def("print", 3186b56d1ec6SPeter Hawkins nb::overload_cast<PyAsmState &, nb::object, bool>( 3187204acc5cSJacques Pienaar &PyOperationBase::print), 3188b56d1ec6SPeter Hawkins nb::arg("state"), nb::arg("file").none() = nb::none(), 3189b56d1ec6SPeter Hawkins nb::arg("binary") = false, kOperationPrintStateDocstring) 3190204acc5cSJacques Pienaar .def("print", 3191b56d1ec6SPeter Hawkins nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool, 3192b56d1ec6SPeter Hawkins bool, nb::object, bool, bool>( 3193abad8455SJonas Rickert &PyOperationBase::print), 3194436c6c9cSStella Laurenzo // Careful: Lots of arguments must match up with print method. 3195b56d1ec6SPeter Hawkins nb::arg("large_elements_limit").none() = nb::none(), 3196b56d1ec6SPeter Hawkins nb::arg("enable_debug_info") = false, 3197b56d1ec6SPeter Hawkins nb::arg("pretty_debug_info") = false, 3198b56d1ec6SPeter Hawkins nb::arg("print_generic_op_form") = false, 3199b56d1ec6SPeter Hawkins nb::arg("use_local_scope") = false, 3200b56d1ec6SPeter Hawkins nb::arg("assume_verified") = false, 3201b56d1ec6SPeter Hawkins nb::arg("file").none() = nb::none(), nb::arg("binary") = false, 3202b56d1ec6SPeter Hawkins nb::arg("skip_regions") = false, kOperationPrintDocstring) 3203b56d1ec6SPeter Hawkins .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), 3204b56d1ec6SPeter Hawkins nb::arg("desired_version").none() = nb::none(), 320589418ddcSMehdi Amini kOperationPrintBytecodeDocstring) 3206436c6c9cSStella Laurenzo .def("get_asm", &PyOperationBase::getAsm, 3207436c6c9cSStella Laurenzo // Careful: Lots of arguments must match up with get_asm method. 3208b56d1ec6SPeter Hawkins nb::arg("binary") = false, 3209b56d1ec6SPeter Hawkins nb::arg("large_elements_limit").none() = nb::none(), 3210b56d1ec6SPeter Hawkins nb::arg("enable_debug_info") = false, 3211b56d1ec6SPeter Hawkins nb::arg("pretty_debug_info") = false, 3212b56d1ec6SPeter Hawkins nb::arg("print_generic_op_form") = false, 3213b56d1ec6SPeter Hawkins nb::arg("use_local_scope") = false, 3214b56d1ec6SPeter Hawkins nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, 3215abad8455SJonas Rickert kOperationGetAsmDocstring) 32163ea4c501SRahul Kayaith .def("verify", &PyOperationBase::verify, 32173ea4c501SRahul Kayaith "Verify the operation. Raises MLIRError if verification fails, and " 32183ea4c501SRahul Kayaith "returns true otherwise.") 3219b56d1ec6SPeter Hawkins .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), 322024685aaeSAlex Zinenko "Puts self immediately after the other operation in its parent " 322124685aaeSAlex Zinenko "block.") 3222b56d1ec6SPeter Hawkins .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), 322324685aaeSAlex Zinenko "Puts self immediately before the other operation in its parent " 322424685aaeSAlex Zinenko "block.") 322524685aaeSAlex Zinenko .def( 322633df617dSStella Laurenzo "clone", 3227b56d1ec6SPeter Hawkins [](PyOperationBase &self, nb::object ip) { 322833df617dSStella Laurenzo return self.getOperation().clone(ip); 322933df617dSStella Laurenzo }, 3230b56d1ec6SPeter Hawkins nb::arg("ip").none() = nb::none()) 323133df617dSStella Laurenzo .def( 323224685aaeSAlex Zinenko "detach_from_parent", 323324685aaeSAlex Zinenko [](PyOperationBase &self) { 323424685aaeSAlex Zinenko PyOperation &operation = self.getOperation(); 323524685aaeSAlex Zinenko operation.checkValid(); 323624685aaeSAlex Zinenko if (!operation.isAttached()) 3237b56d1ec6SPeter Hawkins throw nb::value_error("Detached operation has no parent."); 323824685aaeSAlex Zinenko 323924685aaeSAlex Zinenko operation.detachFromParent(); 324024685aaeSAlex Zinenko return operation.createOpView(); 324124685aaeSAlex Zinenko }, 324233df617dSStella Laurenzo "Detaches the operation from its parent block.") 324347148832SHideto Ueno .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) 3244b56d1ec6SPeter Hawkins .def("walk", &PyOperationBase::walk, nb::arg("callback"), 3245b56d1ec6SPeter Hawkins nb::arg("walk_order") = MlirWalkPostOrder); 3246436c6c9cSStella Laurenzo 3247b56d1ec6SPeter Hawkins nb::class_<PyOperation, PyOperationBase>(m, "Operation") 3248*acde3f72SPeter Hawkins .def_static( 3249*acde3f72SPeter Hawkins "create", 3250*acde3f72SPeter Hawkins [](std::string_view name, 3251*acde3f72SPeter Hawkins std::optional<std::vector<PyType *>> results, 3252*acde3f72SPeter Hawkins std::optional<std::vector<PyValue *>> operands, 3253*acde3f72SPeter Hawkins std::optional<nb::dict> attributes, 3254*acde3f72SPeter Hawkins std::optional<std::vector<PyBlock *>> successors, int regions, 3255*acde3f72SPeter Hawkins DefaultingPyLocation location, const nb::object &maybeIp, 3256*acde3f72SPeter Hawkins bool inferType) { 3257*acde3f72SPeter Hawkins // Unpack/validate operands. 3258*acde3f72SPeter Hawkins llvm::SmallVector<MlirValue, 4> mlirOperands; 3259*acde3f72SPeter Hawkins if (operands) { 3260*acde3f72SPeter Hawkins mlirOperands.reserve(operands->size()); 3261*acde3f72SPeter Hawkins for (PyValue *operand : *operands) { 3262*acde3f72SPeter Hawkins if (!operand) 3263*acde3f72SPeter Hawkins throw nb::value_error("operand value cannot be None"); 3264*acde3f72SPeter Hawkins mlirOperands.push_back(operand->get()); 3265*acde3f72SPeter Hawkins } 3266*acde3f72SPeter Hawkins } 3267*acde3f72SPeter Hawkins 3268*acde3f72SPeter Hawkins return PyOperation::create(name, results, mlirOperands, attributes, 3269*acde3f72SPeter Hawkins successors, regions, location, maybeIp, 3270*acde3f72SPeter Hawkins inferType); 3271*acde3f72SPeter Hawkins }, 3272*acde3f72SPeter Hawkins nb::arg("name"), nb::arg("results").none() = nb::none(), 3273b56d1ec6SPeter Hawkins nb::arg("operands").none() = nb::none(), 3274b56d1ec6SPeter Hawkins nb::arg("attributes").none() = nb::none(), 3275*acde3f72SPeter Hawkins nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0, 3276*acde3f72SPeter Hawkins nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), 3277b56d1ec6SPeter Hawkins nb::arg("infer_type") = false, kOperationCreateDocstring) 327837107e17Srkayaith .def_static( 327937107e17Srkayaith "parse", 328037107e17Srkayaith [](const std::string &sourceStr, const std::string &sourceName, 328137107e17Srkayaith DefaultingPyMlirContext context) { 328237107e17Srkayaith return PyOperation::parse(context->getRef(), sourceStr, sourceName) 328337107e17Srkayaith ->createOpView(); 328437107e17Srkayaith }, 3285b56d1ec6SPeter Hawkins nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", 3286b56d1ec6SPeter Hawkins nb::arg("context").none() = nb::none(), 328737107e17Srkayaith "Parses an operation. Supports both text assembly format and binary " 328837107e17Srkayaith "bytecode format.") 3289b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) 32900126e906SJohn Demme .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 3291b56d1ec6SPeter Hawkins .def_prop_ro("operation", [](nb::object self) { return self; }) 3292b56d1ec6SPeter Hawkins .def_prop_ro("opview", &PyOperation::createOpView) 3293b56d1ec6SPeter Hawkins .def_prop_ro( 3294d7e49736SMaksim Levental "successors", 3295d7e49736SMaksim Levental [](PyOperationBase &self) { 3296d7e49736SMaksim Levental return PyOpSuccessors(self.getOperation().getRef()); 3297d7e49736SMaksim Levental }, 3298d7e49736SMaksim Levental "Returns the list of Operation successors."); 3299436c6c9cSStella Laurenzo 3300436c6c9cSStella Laurenzo auto opViewClass = 3301b56d1ec6SPeter Hawkins nb::class_<PyOpView, PyOperationBase>(m, "OpView") 3302b56d1ec6SPeter Hawkins .def(nb::init<nb::object>(), nb::arg("operation")) 3303f4125e02SPeter Hawkins .def( 3304f4125e02SPeter Hawkins "__init__", 3305f4125e02SPeter Hawkins [](PyOpView *self, std::string_view name, 3306f4125e02SPeter Hawkins std::tuple<int, bool> opRegionSpec, 3307f4125e02SPeter Hawkins nb::object operandSegmentSpecObj, 3308f4125e02SPeter Hawkins nb::object resultSegmentSpecObj, 3309f4125e02SPeter Hawkins std::optional<nb::list> resultTypeList, nb::list operandList, 3310f4125e02SPeter Hawkins std::optional<nb::dict> attributes, 3311f4125e02SPeter Hawkins std::optional<std::vector<PyBlock *>> successors, 3312f4125e02SPeter Hawkins std::optional<int> regions, DefaultingPyLocation location, 3313f4125e02SPeter Hawkins const nb::object &maybeIp) { 3314f4125e02SPeter Hawkins new (self) PyOpView(PyOpView::buildGeneric( 3315f4125e02SPeter Hawkins name, opRegionSpec, operandSegmentSpecObj, 3316f4125e02SPeter Hawkins resultSegmentSpecObj, resultTypeList, operandList, 3317f4125e02SPeter Hawkins attributes, successors, regions, location, maybeIp)); 3318f4125e02SPeter Hawkins }, 3319f4125e02SPeter Hawkins nb::arg("name"), nb::arg("opRegionSpec"), 3320f4125e02SPeter Hawkins nb::arg("operandSegmentSpecObj").none() = nb::none(), 3321f4125e02SPeter Hawkins nb::arg("resultSegmentSpecObj").none() = nb::none(), 3322f4125e02SPeter Hawkins nb::arg("results").none() = nb::none(), 3323f4125e02SPeter Hawkins nb::arg("operands").none() = nb::none(), 3324f4125e02SPeter Hawkins nb::arg("attributes").none() = nb::none(), 3325f4125e02SPeter Hawkins nb::arg("successors").none() = nb::none(), 3326f4125e02SPeter Hawkins nb::arg("regions").none() = nb::none(), 3327f4125e02SPeter Hawkins nb::arg("loc").none() = nb::none(), 3328f4125e02SPeter Hawkins nb::arg("ip").none() = nb::none()) 3329f4125e02SPeter Hawkins 3330b56d1ec6SPeter Hawkins .def_prop_ro("operation", &PyOpView::getOperationObject) 3331b56d1ec6SPeter Hawkins .def_prop_ro("opview", [](nb::object self) { return self; }) 3332d7e49736SMaksim Levental .def( 3333d7e49736SMaksim Levental "__str__", 3334b56d1ec6SPeter Hawkins [](PyOpView &self) { return nb::str(self.getOperationObject()); }) 3335b56d1ec6SPeter Hawkins .def_prop_ro( 3336d7e49736SMaksim Levental "successors", 3337d7e49736SMaksim Levental [](PyOperationBase &self) { 3338d7e49736SMaksim Levental return PyOpSuccessors(self.getOperation().getRef()); 3339d7e49736SMaksim Levental }, 3340d7e49736SMaksim Levental "Returns the list of Operation successors."); 3341b56d1ec6SPeter Hawkins opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); 3342b56d1ec6SPeter Hawkins opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); 3343b56d1ec6SPeter Hawkins opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); 3344f4125e02SPeter Hawkins // It is faster to pass the operation_name, ods_regions, and 3345f4125e02SPeter Hawkins // ods_operand_segments/ods_result_segments as arguments to the constructor, 3346f4125e02SPeter Hawkins // rather than to access them as attributes. 3347436c6c9cSStella Laurenzo opViewClass.attr("build_generic") = classmethod( 3348f4125e02SPeter Hawkins [](nb::handle cls, std::optional<nb::list> resultTypeList, 3349f4125e02SPeter Hawkins nb::list operandList, std::optional<nb::dict> attributes, 3350f4125e02SPeter Hawkins std::optional<std::vector<PyBlock *>> successors, 3351f4125e02SPeter Hawkins std::optional<int> regions, DefaultingPyLocation location, 3352f4125e02SPeter Hawkins const nb::object &maybeIp) { 3353f4125e02SPeter Hawkins std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME")); 3354f4125e02SPeter Hawkins std::tuple<int, bool> opRegionSpec = 3355f4125e02SPeter Hawkins nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 3356f4125e02SPeter Hawkins nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); 3357f4125e02SPeter Hawkins nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); 3358f4125e02SPeter Hawkins return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, 3359f4125e02SPeter Hawkins resultSegmentSpec, resultTypeList, 3360f4125e02SPeter Hawkins operandList, attributes, successors, 3361f4125e02SPeter Hawkins regions, location, maybeIp); 3362f4125e02SPeter Hawkins }, 3363f4125e02SPeter Hawkins nb::arg("cls"), nb::arg("results").none() = nb::none(), 3364b56d1ec6SPeter Hawkins nb::arg("operands").none() = nb::none(), 3365b56d1ec6SPeter Hawkins nb::arg("attributes").none() = nb::none(), 3366b56d1ec6SPeter Hawkins nb::arg("successors").none() = nb::none(), 3367b56d1ec6SPeter Hawkins nb::arg("regions").none() = nb::none(), 3368b56d1ec6SPeter Hawkins nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), 3369436c6c9cSStella Laurenzo "Builds a specific, generated OpView based on class level attributes."); 337037107e17Srkayaith opViewClass.attr("parse") = classmethod( 3371b56d1ec6SPeter Hawkins [](const nb::object &cls, const std::string &sourceStr, 337237107e17Srkayaith const std::string &sourceName, DefaultingPyMlirContext context) { 337337107e17Srkayaith PyOperationRef parsed = 337437107e17Srkayaith PyOperation::parse(context->getRef(), sourceStr, sourceName); 337537107e17Srkayaith 337637107e17Srkayaith // Check if the expected operation was parsed, and cast to to the 337737107e17Srkayaith // appropriate `OpView` subclass if successful. 337837107e17Srkayaith // NOTE: This accesses attributes that have been automatically added to 337937107e17Srkayaith // `OpView` subclasses, and is not intended to be used on `OpView` 338037107e17Srkayaith // directly. 338137107e17Srkayaith std::string clsOpName = 3382b56d1ec6SPeter Hawkins nb::cast<std::string>(cls.attr("OPERATION_NAME")); 33833ea4c501SRahul Kayaith MlirStringRef identifier = 338437107e17Srkayaith mlirIdentifierStr(mlirOperationGetName(*parsed.get())); 33853ea4c501SRahul Kayaith std::string_view parsedOpName(identifier.data, identifier.length); 33863ea4c501SRahul Kayaith if (clsOpName != parsedOpName) 33873ea4c501SRahul Kayaith throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + 33883ea4c501SRahul Kayaith parsedOpName + "'"); 3389b56d1ec6SPeter Hawkins return PyOpView::constructDerived(cls, parsed.getObject()); 339037107e17Srkayaith }, 3391b56d1ec6SPeter Hawkins nb::arg("cls"), nb::arg("source"), nb::kw_only(), 3392b56d1ec6SPeter Hawkins nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), 339337107e17Srkayaith "Parses a specific, generated OpView based on class level attributes"); 3394436c6c9cSStella Laurenzo 3395436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3396436c6c9cSStella Laurenzo // Mapping of PyRegion. 3397436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3398b56d1ec6SPeter Hawkins nb::class_<PyRegion>(m, "Region") 3399b56d1ec6SPeter Hawkins .def_prop_ro( 3400436c6c9cSStella Laurenzo "blocks", 3401436c6c9cSStella Laurenzo [](PyRegion &self) { 3402436c6c9cSStella Laurenzo return PyBlockList(self.getParentOperation(), self.get()); 3403436c6c9cSStella Laurenzo }, 3404436c6c9cSStella Laurenzo "Returns a forward-optimized sequence of blocks.") 3405b56d1ec6SPeter Hawkins .def_prop_ro( 340678f2dae0SAlex Zinenko "owner", 340778f2dae0SAlex Zinenko [](PyRegion &self) { 340878f2dae0SAlex Zinenko return self.getParentOperation()->createOpView(); 340978f2dae0SAlex Zinenko }, 341078f2dae0SAlex Zinenko "Returns the operation owning this region.") 3411436c6c9cSStella Laurenzo .def( 3412436c6c9cSStella Laurenzo "__iter__", 3413436c6c9cSStella Laurenzo [](PyRegion &self) { 3414436c6c9cSStella Laurenzo self.checkValid(); 3415436c6c9cSStella Laurenzo MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 3416436c6c9cSStella Laurenzo return PyBlockIterator(self.getParentOperation(), firstBlock); 3417436c6c9cSStella Laurenzo }, 3418436c6c9cSStella Laurenzo "Iterates over blocks in the region.") 3419436c6c9cSStella Laurenzo .def("__eq__", 3420436c6c9cSStella Laurenzo [](PyRegion &self, PyRegion &other) { 3421436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 3422436c6c9cSStella Laurenzo }) 3423b56d1ec6SPeter Hawkins .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); 3424436c6c9cSStella Laurenzo 3425436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3426436c6c9cSStella Laurenzo // Mapping of PyBlock. 3427436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3428b56d1ec6SPeter Hawkins nb::class_<PyBlock>(m, "Block") 3429b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) 3430b56d1ec6SPeter Hawkins .def_prop_ro( 343196fbd5cdSJohn Demme "owner", 343296fbd5cdSJohn Demme [](PyBlock &self) { 343396fbd5cdSJohn Demme return self.getParentOperation()->createOpView(); 343496fbd5cdSJohn Demme }, 343596fbd5cdSJohn Demme "Returns the owning operation of this block.") 3436b56d1ec6SPeter Hawkins .def_prop_ro( 34378e6c55c9SStella Laurenzo "region", 34388e6c55c9SStella Laurenzo [](PyBlock &self) { 34398e6c55c9SStella Laurenzo MlirRegion region = mlirBlockGetParentRegion(self.get()); 34408e6c55c9SStella Laurenzo return PyRegion(self.getParentOperation(), region); 34418e6c55c9SStella Laurenzo }, 34428e6c55c9SStella Laurenzo "Returns the owning region of this block.") 3443b56d1ec6SPeter Hawkins .def_prop_ro( 3444436c6c9cSStella Laurenzo "arguments", 3445436c6c9cSStella Laurenzo [](PyBlock &self) { 3446436c6c9cSStella Laurenzo return PyBlockArgumentList(self.getParentOperation(), self.get()); 3447436c6c9cSStella Laurenzo }, 3448436c6c9cSStella Laurenzo "Returns a list of block arguments.") 344955d2fffdSSandeep Dasgupta .def( 345055d2fffdSSandeep Dasgupta "add_argument", 345155d2fffdSSandeep Dasgupta [](PyBlock &self, const PyType &type, const PyLocation &loc) { 345255d2fffdSSandeep Dasgupta return mlirBlockAddArgument(self.get(), type, loc); 345355d2fffdSSandeep Dasgupta }, 345455d2fffdSSandeep Dasgupta "Append an argument of the specified type to the block and returns " 345555d2fffdSSandeep Dasgupta "the newly added argument.") 345655d2fffdSSandeep Dasgupta .def( 345755d2fffdSSandeep Dasgupta "erase_argument", 345855d2fffdSSandeep Dasgupta [](PyBlock &self, unsigned index) { 345955d2fffdSSandeep Dasgupta return mlirBlockEraseArgument(self.get(), index); 346055d2fffdSSandeep Dasgupta }, 346155d2fffdSSandeep Dasgupta "Erase the argument at 'index' and remove it from the argument list.") 3462b56d1ec6SPeter Hawkins .def_prop_ro( 3463436c6c9cSStella Laurenzo "operations", 3464436c6c9cSStella Laurenzo [](PyBlock &self) { 3465436c6c9cSStella Laurenzo return PyOperationList(self.getParentOperation(), self.get()); 3466436c6c9cSStella Laurenzo }, 3467436c6c9cSStella Laurenzo "Returns a forward-optimized sequence of operations.") 346878f2dae0SAlex Zinenko .def_static( 346978f2dae0SAlex Zinenko "create_at_start", 3470b56d1ec6SPeter Hawkins [](PyRegion &parent, const nb::sequence &pyArgTypes, 3471b56d1ec6SPeter Hawkins const std::optional<nb::sequence> &pyArgLocs) { 347278f2dae0SAlex Zinenko parent.checkValid(); 3473514dddbeSRahul Kayaith MlirBlock block = createBlock(pyArgTypes, pyArgLocs); 347478f2dae0SAlex Zinenko mlirRegionInsertOwnedBlock(parent, 0, block); 347578f2dae0SAlex Zinenko return PyBlock(parent.getParentOperation(), block); 347678f2dae0SAlex Zinenko }, 3477b56d1ec6SPeter Hawkins nb::arg("parent"), nb::arg("arg_types") = nb::list(), 3478b56d1ec6SPeter Hawkins nb::arg("arg_locs") = std::nullopt, 347978f2dae0SAlex Zinenko "Creates and returns a new Block at the beginning of the given " 3480514dddbeSRahul Kayaith "region (with given argument types and locations).") 3481436c6c9cSStella Laurenzo .def( 34828d8738f6SJohn Demme "append_to", 34838d8738f6SJohn Demme [](PyBlock &self, PyRegion ®ion) { 34848d8738f6SJohn Demme MlirBlock b = self.get(); 34858d8738f6SJohn Demme if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 34868d8738f6SJohn Demme mlirBlockDetach(b); 34878d8738f6SJohn Demme mlirRegionAppendOwnedBlock(region.get(), b); 34888d8738f6SJohn Demme }, 34898d8738f6SJohn Demme "Append this block to a region, transferring ownership if necessary") 34908d8738f6SJohn Demme .def( 34918e6c55c9SStella Laurenzo "create_before", 3492b56d1ec6SPeter Hawkins [](PyBlock &self, const nb::args &pyArgTypes, 3493b56d1ec6SPeter Hawkins const std::optional<nb::sequence> &pyArgLocs) { 34948e6c55c9SStella Laurenzo self.checkValid(); 3495b56d1ec6SPeter Hawkins MlirBlock block = 3496b56d1ec6SPeter Hawkins createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); 34978e6c55c9SStella Laurenzo MlirRegion region = mlirBlockGetParentRegion(self.get()); 34988e6c55c9SStella Laurenzo mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 34998e6c55c9SStella Laurenzo return PyBlock(self.getParentOperation(), block); 35008e6c55c9SStella Laurenzo }, 3501b56d1ec6SPeter Hawkins nb::arg("arg_types"), nb::kw_only(), 3502b56d1ec6SPeter Hawkins nb::arg("arg_locs") = std::nullopt, 35038e6c55c9SStella Laurenzo "Creates and returns a new Block before this block " 3504514dddbeSRahul Kayaith "(with given argument types and locations).") 35058e6c55c9SStella Laurenzo .def( 35068e6c55c9SStella Laurenzo "create_after", 3507b56d1ec6SPeter Hawkins [](PyBlock &self, const nb::args &pyArgTypes, 3508b56d1ec6SPeter Hawkins const std::optional<nb::sequence> &pyArgLocs) { 35098e6c55c9SStella Laurenzo self.checkValid(); 3510b56d1ec6SPeter Hawkins MlirBlock block = 3511b56d1ec6SPeter Hawkins createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); 35128e6c55c9SStella Laurenzo MlirRegion region = mlirBlockGetParentRegion(self.get()); 35138e6c55c9SStella Laurenzo mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 35148e6c55c9SStella Laurenzo return PyBlock(self.getParentOperation(), block); 35158e6c55c9SStella Laurenzo }, 3516b56d1ec6SPeter Hawkins nb::arg("arg_types"), nb::kw_only(), 3517b56d1ec6SPeter Hawkins nb::arg("arg_locs") = std::nullopt, 35188e6c55c9SStella Laurenzo "Creates and returns a new Block after this block " 3519514dddbeSRahul Kayaith "(with given argument types and locations).") 35208e6c55c9SStella Laurenzo .def( 3521436c6c9cSStella Laurenzo "__iter__", 3522436c6c9cSStella Laurenzo [](PyBlock &self) { 3523436c6c9cSStella Laurenzo self.checkValid(); 3524436c6c9cSStella Laurenzo MlirOperation firstOperation = 3525436c6c9cSStella Laurenzo mlirBlockGetFirstOperation(self.get()); 3526436c6c9cSStella Laurenzo return PyOperationIterator(self.getParentOperation(), 3527436c6c9cSStella Laurenzo firstOperation); 3528436c6c9cSStella Laurenzo }, 3529436c6c9cSStella Laurenzo "Iterates over operations in the block.") 3530436c6c9cSStella Laurenzo .def("__eq__", 3531436c6c9cSStella Laurenzo [](PyBlock &self, PyBlock &other) { 3532436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 3533436c6c9cSStella Laurenzo }) 3534b56d1ec6SPeter Hawkins .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) 3535fa45b2fbSMike Urbach .def("__hash__", 3536fa45b2fbSMike Urbach [](PyBlock &self) { 3537fa45b2fbSMike Urbach return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3538fa45b2fbSMike Urbach }) 3539436c6c9cSStella Laurenzo .def( 3540436c6c9cSStella Laurenzo "__str__", 3541436c6c9cSStella Laurenzo [](PyBlock &self) { 3542436c6c9cSStella Laurenzo self.checkValid(); 3543436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3544436c6c9cSStella Laurenzo mlirBlockPrint(self.get(), printAccum.getCallback(), 3545436c6c9cSStella Laurenzo printAccum.getUserData()); 3546436c6c9cSStella Laurenzo return printAccum.join(); 3547436c6c9cSStella Laurenzo }, 354824685aaeSAlex Zinenko "Returns the assembly form of the block.") 354924685aaeSAlex Zinenko .def( 355024685aaeSAlex Zinenko "append", 355124685aaeSAlex Zinenko [](PyBlock &self, PyOperationBase &operation) { 355224685aaeSAlex Zinenko if (operation.getOperation().isAttached()) 355324685aaeSAlex Zinenko operation.getOperation().detachFromParent(); 355424685aaeSAlex Zinenko 355524685aaeSAlex Zinenko MlirOperation mlirOperation = operation.getOperation().get(); 355624685aaeSAlex Zinenko mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 355724685aaeSAlex Zinenko operation.getOperation().setAttached( 355824685aaeSAlex Zinenko self.getParentOperation().getObject()); 355924685aaeSAlex Zinenko }, 3560b56d1ec6SPeter Hawkins nb::arg("operation"), 356124685aaeSAlex Zinenko "Appends an operation to this block. If the operation is currently " 356224685aaeSAlex Zinenko "in another block, it will be moved."); 3563436c6c9cSStella Laurenzo 3564436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3565436c6c9cSStella Laurenzo // Mapping of PyInsertionPoint. 3566436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3567436c6c9cSStella Laurenzo 3568b56d1ec6SPeter Hawkins nb::class_<PyInsertionPoint>(m, "InsertionPoint") 3569b56d1ec6SPeter Hawkins .def(nb::init<PyBlock &>(), nb::arg("block"), 3570436c6c9cSStella Laurenzo "Inserts after the last operation but still inside the block.") 3571436c6c9cSStella Laurenzo .def("__enter__", &PyInsertionPoint::contextEnter) 3572b56d1ec6SPeter Hawkins .def("__exit__", &PyInsertionPoint::contextExit, 3573b56d1ec6SPeter Hawkins nb::arg("exc_type").none(), nb::arg("exc_value").none(), 3574b56d1ec6SPeter Hawkins nb::arg("traceback").none()) 3575b56d1ec6SPeter Hawkins .def_prop_ro_static( 3576436c6c9cSStella Laurenzo "current", 3577b56d1ec6SPeter Hawkins [](nb::object & /*class*/) { 3578436c6c9cSStella Laurenzo auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 3579436c6c9cSStella Laurenzo if (!ip) 3580b56d1ec6SPeter Hawkins throw nb::value_error("No current InsertionPoint"); 3581436c6c9cSStella Laurenzo return ip; 3582436c6c9cSStella Laurenzo }, 3583436c6c9cSStella Laurenzo "Gets the InsertionPoint bound to the current thread or raises " 3584436c6c9cSStella Laurenzo "ValueError if none has been set") 3585b56d1ec6SPeter Hawkins .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"), 3586436c6c9cSStella Laurenzo "Inserts before a referenced operation.") 3587436c6c9cSStella Laurenzo .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 3588b56d1ec6SPeter Hawkins nb::arg("block"), "Inserts at the beginning of the block.") 3589436c6c9cSStella Laurenzo .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 3590b56d1ec6SPeter Hawkins nb::arg("block"), "Inserts before the block terminator.") 3591b56d1ec6SPeter Hawkins .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), 35928e6c55c9SStella Laurenzo "Inserts an operation.") 3593b56d1ec6SPeter Hawkins .def_prop_ro( 35948e6c55c9SStella Laurenzo "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 35955a600c23STomás Longeri "Returns the block that this InsertionPoint points to.") 3596b56d1ec6SPeter Hawkins .def_prop_ro( 35975a600c23STomás Longeri "ref_operation", 3598b56d1ec6SPeter Hawkins [](PyInsertionPoint &self) -> nb::object { 3599b1d682e0SMehdi Amini auto refOperation = self.getRefOperation(); 3600b1d682e0SMehdi Amini if (refOperation) 3601b1d682e0SMehdi Amini return refOperation->getObject(); 3602b56d1ec6SPeter Hawkins return nb::none(); 36035a600c23STomás Longeri }, 36045a600c23STomás Longeri "The reference operation before which new operations are " 36055a600c23STomás Longeri "inserted, or None if the insertion point is at the end of " 36065a600c23STomás Longeri "the block"); 3607436c6c9cSStella Laurenzo 3608436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3609436c6c9cSStella Laurenzo // Mapping of PyAttribute. 3610436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3611b56d1ec6SPeter Hawkins nb::class_<PyAttribute>(m, "Attribute") 3612b57d6fe4SStella Laurenzo // Delegate to the PyAttribute copy constructor, which will also lifetime 3613b57d6fe4SStella Laurenzo // extend the backing context which owns the MlirAttribute. 3614b56d1ec6SPeter Hawkins .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"), 3615b57d6fe4SStella Laurenzo "Casts the passed attribute to the generic Attribute") 3616b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) 3617436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 3618436c6c9cSStella Laurenzo .def_static( 3619436c6c9cSStella Laurenzo "parse", 36204eee9ef9Smax [](const std::string &attrSpec, DefaultingPyMlirContext context) { 36213ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(context->getRef()); 36224eee9ef9Smax MlirAttribute attr = mlirAttributeParseGet( 3623436c6c9cSStella Laurenzo context->get(), toMlirStringRef(attrSpec)); 36244eee9ef9Smax if (mlirAttributeIsNull(attr)) 36253ea4c501SRahul Kayaith throw MLIRError("Unable to parse attribute", errors.take()); 36264eee9ef9Smax return attr; 3627436c6c9cSStella Laurenzo }, 3628b56d1ec6SPeter Hawkins nb::arg("asm"), nb::arg("context").none() = nb::none(), 36293ea4c501SRahul Kayaith "Parses an attribute from an assembly form. Raises an MLIRError on " 36303ea4c501SRahul Kayaith "failure.") 3631b56d1ec6SPeter Hawkins .def_prop_ro( 3632436c6c9cSStella Laurenzo "context", 3633436c6c9cSStella Laurenzo [](PyAttribute &self) { return self.getContext().getObject(); }, 3634436c6c9cSStella Laurenzo "Context that owns the Attribute") 3635b56d1ec6SPeter Hawkins .def_prop_ro("type", 3636b56d1ec6SPeter Hawkins [](PyAttribute &self) { return mlirAttributeGetType(self); }) 3637436c6c9cSStella Laurenzo .def( 3638436c6c9cSStella Laurenzo "get_named", 3639436c6c9cSStella Laurenzo [](PyAttribute &self, std::string name) { 3640436c6c9cSStella Laurenzo return PyNamedAttribute(self, std::move(name)); 3641436c6c9cSStella Laurenzo }, 3642b56d1ec6SPeter Hawkins nb::keep_alive<0, 1>(), "Binds a name to the attribute") 3643436c6c9cSStella Laurenzo .def("__eq__", 3644436c6c9cSStella Laurenzo [](PyAttribute &self, PyAttribute &other) { return self == other; }) 3645b56d1ec6SPeter Hawkins .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) 3646f78fe0b7Srkayaith .def("__hash__", 3647f78fe0b7Srkayaith [](PyAttribute &self) { 3648f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3649f78fe0b7Srkayaith }) 3650436c6c9cSStella Laurenzo .def( 3651436c6c9cSStella Laurenzo "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 3652436c6c9cSStella Laurenzo kDumpDocstring) 3653436c6c9cSStella Laurenzo .def( 3654436c6c9cSStella Laurenzo "__str__", 3655436c6c9cSStella Laurenzo [](PyAttribute &self) { 3656436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3657436c6c9cSStella Laurenzo mlirAttributePrint(self, printAccum.getCallback(), 3658436c6c9cSStella Laurenzo printAccum.getUserData()); 3659436c6c9cSStella Laurenzo return printAccum.join(); 3660436c6c9cSStella Laurenzo }, 3661436c6c9cSStella Laurenzo "Returns the assembly form of the Attribute.") 36629566ee28Smax .def("__repr__", 36639566ee28Smax [](PyAttribute &self) { 3664436c6c9cSStella Laurenzo // Generally, assembly formats are not printed for __repr__ because 3665436c6c9cSStella Laurenzo // this can cause exceptionally long debug output and exceptions. 36669566ee28Smax // However, attribute values are generally considered useful and 36679566ee28Smax // are printed. This may need to be re-evaluated if debug dumps end 36689566ee28Smax // up being excessive. 3669436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3670436c6c9cSStella Laurenzo printAccum.parts.append("Attribute("); 3671436c6c9cSStella Laurenzo mlirAttributePrint(self, printAccum.getCallback(), 3672436c6c9cSStella Laurenzo printAccum.getUserData()); 3673436c6c9cSStella Laurenzo printAccum.parts.append(")"); 3674436c6c9cSStella Laurenzo return printAccum.join(); 36759566ee28Smax }) 3676b56d1ec6SPeter Hawkins .def_prop_ro("typeid", 36779566ee28Smax [](PyAttribute &self) -> MlirTypeID { 36789566ee28Smax MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); 36799566ee28Smax assert(!mlirTypeIDIsNull(mlirTypeID) && 36809566ee28Smax "mlirTypeID was expected to be non-null."); 36819566ee28Smax return mlirTypeID; 36829566ee28Smax }) 36839566ee28Smax .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { 36849566ee28Smax MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); 36859566ee28Smax assert(!mlirTypeIDIsNull(mlirTypeID) && 36869566ee28Smax "mlirTypeID was expected to be non-null."); 3687b56d1ec6SPeter Hawkins std::optional<nb::callable> typeCaster = 36889566ee28Smax PyGlobals::get().lookupTypeCaster(mlirTypeID, 36899566ee28Smax mlirAttributeGetDialect(self)); 36909566ee28Smax if (!typeCaster) 3691b56d1ec6SPeter Hawkins return nb::cast(self); 36929566ee28Smax return typeCaster.value()(self); 3693436c6c9cSStella Laurenzo }); 3694436c6c9cSStella Laurenzo 3695436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3696436c6c9cSStella Laurenzo // Mapping of PyNamedAttribute 3697436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3698b56d1ec6SPeter Hawkins nb::class_<PyNamedAttribute>(m, "NamedAttribute") 3699436c6c9cSStella Laurenzo .def("__repr__", 3700436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 3701436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3702436c6c9cSStella Laurenzo printAccum.parts.append("NamedAttribute("); 3703436c6c9cSStella Laurenzo printAccum.parts.append( 3704b56d1ec6SPeter Hawkins nb::str(mlirIdentifierStr(self.namedAttr.name).data, 3705120591e1SRiver Riddle mlirIdentifierStr(self.namedAttr.name).length)); 3706436c6c9cSStella Laurenzo printAccum.parts.append("="); 3707436c6c9cSStella Laurenzo mlirAttributePrint(self.namedAttr.attribute, 3708436c6c9cSStella Laurenzo printAccum.getCallback(), 3709436c6c9cSStella Laurenzo printAccum.getUserData()); 3710436c6c9cSStella Laurenzo printAccum.parts.append(")"); 3711436c6c9cSStella Laurenzo return printAccum.join(); 3712436c6c9cSStella Laurenzo }) 3713b56d1ec6SPeter Hawkins .def_prop_ro( 3714436c6c9cSStella Laurenzo "name", 3715436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 3716b56d1ec6SPeter Hawkins return nb::str(mlirIdentifierStr(self.namedAttr.name).data, 3717436c6c9cSStella Laurenzo mlirIdentifierStr(self.namedAttr.name).length); 3718436c6c9cSStella Laurenzo }, 3719436c6c9cSStella Laurenzo "The name of the NamedAttribute binding") 3720b56d1ec6SPeter Hawkins .def_prop_ro( 3721436c6c9cSStella Laurenzo "attr", 37229566ee28Smax [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, 3723b56d1ec6SPeter Hawkins nb::keep_alive<0, 1>(), 3724436c6c9cSStella Laurenzo "The underlying generic attribute of the NamedAttribute binding"); 3725436c6c9cSStella Laurenzo 3726436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3727436c6c9cSStella Laurenzo // Mapping of PyType. 3728436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3729b56d1ec6SPeter Hawkins nb::class_<PyType>(m, "Type") 3730b57d6fe4SStella Laurenzo // Delegate to the PyType copy constructor, which will also lifetime 3731b57d6fe4SStella Laurenzo // extend the backing context which owns the MlirType. 3732b56d1ec6SPeter Hawkins .def(nb::init<PyType &>(), nb::arg("cast_from_type"), 3733b57d6fe4SStella Laurenzo "Casts the passed type to the generic Type") 3734b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3735436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3736436c6c9cSStella Laurenzo .def_static( 3737436c6c9cSStella Laurenzo "parse", 3738436c6c9cSStella Laurenzo [](std::string typeSpec, DefaultingPyMlirContext context) { 37393ea4c501SRahul Kayaith PyMlirContext::ErrorCapture errors(context->getRef()); 3740436c6c9cSStella Laurenzo MlirType type = 3741436c6c9cSStella Laurenzo mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 37423ea4c501SRahul Kayaith if (mlirTypeIsNull(type)) 37433ea4c501SRahul Kayaith throw MLIRError("Unable to parse type", errors.take()); 3744bfb1ba75Smax return type; 3745436c6c9cSStella Laurenzo }, 3746b56d1ec6SPeter Hawkins nb::arg("asm"), nb::arg("context").none() = nb::none(), 3747436c6c9cSStella Laurenzo kContextParseTypeDocstring) 3748b56d1ec6SPeter Hawkins .def_prop_ro( 3749436c6c9cSStella Laurenzo "context", [](PyType &self) { return self.getContext().getObject(); }, 3750436c6c9cSStella Laurenzo "Context that owns the Type") 3751436c6c9cSStella Laurenzo .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3752b56d1ec6SPeter Hawkins .def( 3753b56d1ec6SPeter Hawkins "__eq__", [](PyType &self, nb::object &other) { return false; }, 3754b56d1ec6SPeter Hawkins nb::arg("other").none()) 3755f78fe0b7Srkayaith .def("__hash__", 3756f78fe0b7Srkayaith [](PyType &self) { 3757f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3758f78fe0b7Srkayaith }) 3759436c6c9cSStella Laurenzo .def( 3760436c6c9cSStella Laurenzo "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3761436c6c9cSStella Laurenzo .def( 3762436c6c9cSStella Laurenzo "__str__", 3763436c6c9cSStella Laurenzo [](PyType &self) { 3764436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3765436c6c9cSStella Laurenzo mlirTypePrint(self, printAccum.getCallback(), 3766436c6c9cSStella Laurenzo printAccum.getUserData()); 3767436c6c9cSStella Laurenzo return printAccum.join(); 3768436c6c9cSStella Laurenzo }, 3769436c6c9cSStella Laurenzo "Returns the assembly form of the type.") 3770d39a7844Smax .def("__repr__", 3771d39a7844Smax [](PyType &self) { 3772436c6c9cSStella Laurenzo // Generally, assembly formats are not printed for __repr__ because 3773436c6c9cSStella Laurenzo // this can cause exceptionally long debug output and exceptions. 3774436c6c9cSStella Laurenzo // However, types are an exception as they typically have compact 3775436c6c9cSStella Laurenzo // assembly forms and printing them is useful. 3776436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3777436c6c9cSStella Laurenzo printAccum.parts.append("Type("); 3778d39a7844Smax mlirTypePrint(self, printAccum.getCallback(), 3779d39a7844Smax printAccum.getUserData()); 3780436c6c9cSStella Laurenzo printAccum.parts.append(")"); 3781436c6c9cSStella Laurenzo return printAccum.join(); 3782d39a7844Smax }) 3783bfb1ba75Smax .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, 3784bfb1ba75Smax [](PyType &self) { 3785bfb1ba75Smax MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); 3786bfb1ba75Smax assert(!mlirTypeIDIsNull(mlirTypeID) && 3787bfb1ba75Smax "mlirTypeID was expected to be non-null."); 3788b56d1ec6SPeter Hawkins std::optional<nb::callable> typeCaster = 3789bfb1ba75Smax PyGlobals::get().lookupTypeCaster(mlirTypeID, 3790bfb1ba75Smax mlirTypeGetDialect(self)); 3791bfb1ba75Smax if (!typeCaster) 3792b56d1ec6SPeter Hawkins return nb::cast(self); 3793bfb1ba75Smax return typeCaster.value()(self); 3794bfb1ba75Smax }) 3795b56d1ec6SPeter Hawkins .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { 3796d39a7844Smax MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); 3797d39a7844Smax if (!mlirTypeIDIsNull(mlirTypeID)) 3798d39a7844Smax return mlirTypeID; 3799b56d1ec6SPeter Hawkins auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); 3800b56d1ec6SPeter Hawkins throw nb::value_error( 3801b56d1ec6SPeter Hawkins (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); 3802d39a7844Smax }); 3803d39a7844Smax 3804d39a7844Smax //---------------------------------------------------------------------------- 3805d39a7844Smax // Mapping of PyTypeID. 3806d39a7844Smax //---------------------------------------------------------------------------- 3807b56d1ec6SPeter Hawkins nb::class_<PyTypeID>(m, "TypeID") 3808b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) 3809d39a7844Smax .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) 3810d39a7844Smax // Note, this tests whether the underlying TypeIDs are the same, 3811d39a7844Smax // not whether the wrapper MlirTypeIDs are the same, nor whether 3812d39a7844Smax // the Python objects are the same (i.e., PyTypeID is a value type). 3813d39a7844Smax .def("__eq__", 3814d39a7844Smax [](PyTypeID &self, PyTypeID &other) { return self == other; }) 3815d39a7844Smax .def("__eq__", 3816b56d1ec6SPeter Hawkins [](PyTypeID &self, const nb::object &other) { return false; }) 3817d39a7844Smax // Note, this gives the hash value of the underlying TypeID, not the 3818d39a7844Smax // hash value of the Python object, nor the hash value of the 3819d39a7844Smax // MlirTypeID wrapper. 3820d39a7844Smax .def("__hash__", [](PyTypeID &self) { 3821d39a7844Smax return static_cast<size_t>(mlirTypeIDHashValue(self)); 3822436c6c9cSStella Laurenzo }); 3823436c6c9cSStella Laurenzo 3824436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3825436c6c9cSStella Laurenzo // Mapping of Value. 3826436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 3827b56d1ec6SPeter Hawkins nb::class_<PyValue>(m, "Value") 3828b56d1ec6SPeter Hawkins .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")) 3829b56d1ec6SPeter Hawkins .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 38303f3d1c90SMike Urbach .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3831b56d1ec6SPeter Hawkins .def_prop_ro( 3832436c6c9cSStella Laurenzo "context", 3833436c6c9cSStella Laurenzo [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3834436c6c9cSStella Laurenzo "Context in which the value lives.") 3835436c6c9cSStella Laurenzo .def( 3836436c6c9cSStella Laurenzo "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3837436c6c9cSStella Laurenzo kDumpDocstring) 3838b56d1ec6SPeter Hawkins .def_prop_ro( 38395664c5e2SJohn Demme "owner", 3840b56d1ec6SPeter Hawkins [](PyValue &self) -> nb::object { 3841d747a170SJohn Demme MlirValue v = self.get(); 3842d747a170SJohn Demme if (mlirValueIsAOpResult(v)) { 3843d747a170SJohn Demme assert( 3844d747a170SJohn Demme mlirOperationEqual(self.getParentOperation()->get(), 38455664c5e2SJohn Demme mlirOpResultGetOwner(self.get())) && 38465664c5e2SJohn Demme "expected the owner of the value in Python to match that in " 38475664c5e2SJohn Demme "the IR"); 38485664c5e2SJohn Demme return self.getParentOperation().getObject(); 3849d747a170SJohn Demme } 3850d747a170SJohn Demme 3851d747a170SJohn Demme if (mlirValueIsABlockArgument(v)) { 3852d747a170SJohn Demme MlirBlock block = mlirBlockArgumentGetOwner(self.get()); 3853b56d1ec6SPeter Hawkins return nb::cast(PyBlock(self.getParentOperation(), block)); 3854d747a170SJohn Demme } 3855d747a170SJohn Demme 3856d747a170SJohn Demme assert(false && "Value must be a block argument or an op result"); 3857b56d1ec6SPeter Hawkins return nb::none(); 38585664c5e2SJohn Demme }) 3859b56d1ec6SPeter Hawkins .def_prop_ro("uses", 3860afb2ed80SMike Urbach [](PyValue &self) { 3861afb2ed80SMike Urbach return PyOpOperandIterator( 3862afb2ed80SMike Urbach mlirValueGetFirstUse(self.get())); 3863afb2ed80SMike Urbach }) 3864436c6c9cSStella Laurenzo .def("__eq__", 3865436c6c9cSStella Laurenzo [](PyValue &self, PyValue &other) { 3866436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 3867436c6c9cSStella Laurenzo }) 3868b56d1ec6SPeter Hawkins .def("__eq__", [](PyValue &self, nb::object other) { return false; }) 3869f78fe0b7Srkayaith .def("__hash__", 3870f78fe0b7Srkayaith [](PyValue &self) { 3871f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3872f78fe0b7Srkayaith }) 3873436c6c9cSStella Laurenzo .def( 3874436c6c9cSStella Laurenzo "__str__", 3875436c6c9cSStella Laurenzo [](PyValue &self) { 3876436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 3877436c6c9cSStella Laurenzo printAccum.parts.append("Value("); 3878436c6c9cSStella Laurenzo mlirValuePrint(self.get(), printAccum.getCallback(), 3879436c6c9cSStella Laurenzo printAccum.getUserData()); 3880436c6c9cSStella Laurenzo printAccum.parts.append(")"); 3881436c6c9cSStella Laurenzo return printAccum.join(); 3882436c6c9cSStella Laurenzo }, 3883436c6c9cSStella Laurenzo kValueDunderStrDocstring) 388481233c70Smax .def( 388581233c70Smax "get_name", 3886f1dbfcc1SJacques Pienaar [](PyValue &self, bool useLocalScope) { 388781233c70Smax PyPrintAccumulator printAccum; 3888f1dbfcc1SJacques Pienaar MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 3889f1dbfcc1SJacques Pienaar if (useLocalScope) 389081233c70Smax mlirOpPrintingFlagsUseLocalScope(flags); 3891f1dbfcc1SJacques Pienaar MlirAsmState valueState = 3892f1dbfcc1SJacques Pienaar mlirAsmStateCreateForValue(self.get(), flags); 3893d7e49736SMaksim Levental mlirValuePrintAsOperand(self.get(), valueState, 3894d7e49736SMaksim Levental printAccum.getCallback(), 389581233c70Smax printAccum.getUserData()); 389681233c70Smax mlirOpPrintingFlagsDestroy(flags); 389775453714SJacques Pienaar mlirAsmStateDestroy(valueState); 389881233c70Smax return printAccum.join(); 389981233c70Smax }, 3900b56d1ec6SPeter Hawkins nb::arg("use_local_scope") = false) 3901f1dbfcc1SJacques Pienaar .def( 3902f1dbfcc1SJacques Pienaar "get_name", 3903b56d1ec6SPeter Hawkins [](PyValue &self, PyAsmState &state) { 3904f1dbfcc1SJacques Pienaar PyPrintAccumulator printAccum; 3905b56d1ec6SPeter Hawkins MlirAsmState valueState = state.get(); 3906f1dbfcc1SJacques Pienaar mlirValuePrintAsOperand(self.get(), valueState, 3907f1dbfcc1SJacques Pienaar printAccum.getCallback(), 3908f1dbfcc1SJacques Pienaar printAccum.getUserData()); 3909f1dbfcc1SJacques Pienaar return printAccum.join(); 3910f1dbfcc1SJacques Pienaar }, 3911b56d1ec6SPeter Hawkins nb::arg("state"), kGetNameAsOperand) 3912b56d1ec6SPeter Hawkins .def_prop_ro("type", 3913b56d1ec6SPeter Hawkins [](PyValue &self) { return mlirValueGetType(self.get()); }) 39145b303f21Smax .def( 391525b8433bSmax "set_type", 391625b8433bSmax [](PyValue &self, const PyType &type) { 391725b8433bSmax return mlirValueSetType(self.get(), type); 391825b8433bSmax }, 3919b56d1ec6SPeter Hawkins nb::arg("type")) 392025b8433bSmax .def( 39215b303f21Smax "replace_all_uses_with", 39225b303f21Smax [](PyValue &self, PyValue &with) { 39235b303f21Smax mlirValueReplaceAllUsesOfWith(self.get(), with.get()); 39245b303f21Smax }, 39257c850867SMaksim Levental kValueReplaceAllUsesWithDocstring) 392621df3251SPerry Gibson .def( 392721df3251SPerry Gibson "replace_all_uses_except", 392821df3251SPerry Gibson [](MlirValue self, MlirValue with, PyOperation &exception) { 392921df3251SPerry Gibson MlirOperation exceptedUser = exception.get(); 393021df3251SPerry Gibson mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); 393121df3251SPerry Gibson }, 3932b56d1ec6SPeter Hawkins nb::arg("with"), nb::arg("exceptions"), 393321df3251SPerry Gibson kValueReplaceAllUsesExceptDocstring) 393421df3251SPerry Gibson .def( 393521df3251SPerry Gibson "replace_all_uses_except", 3936b56d1ec6SPeter Hawkins [](MlirValue self, MlirValue with, nb::list exceptions) { 393721df3251SPerry Gibson // Convert Python list to a SmallVector of MlirOperations 393821df3251SPerry Gibson llvm::SmallVector<MlirOperation> exceptionOps; 3939b56d1ec6SPeter Hawkins for (nb::handle exception : exceptions) { 3940b56d1ec6SPeter Hawkins exceptionOps.push_back(nb::cast<PyOperation &>(exception).get()); 394121df3251SPerry Gibson } 394221df3251SPerry Gibson 394321df3251SPerry Gibson mlirValueReplaceAllUsesExcept( 394421df3251SPerry Gibson self, with, static_cast<intptr_t>(exceptionOps.size()), 394521df3251SPerry Gibson exceptionOps.data()); 394621df3251SPerry Gibson }, 3947b56d1ec6SPeter Hawkins nb::arg("with"), nb::arg("exceptions"), 394821df3251SPerry Gibson kValueReplaceAllUsesExceptDocstring) 39497c850867SMaksim Levental .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, 39507c850867SMaksim Levental [](PyValue &self) { return self.maybeDownCast(); }); 3951436c6c9cSStella Laurenzo PyBlockArgument::bind(m); 3952436c6c9cSStella Laurenzo PyOpResult::bind(m); 3953afb2ed80SMike Urbach PyOpOperand::bind(m); 3954436c6c9cSStella Laurenzo 3955b56d1ec6SPeter Hawkins nb::class_<PyAsmState>(m, "AsmState") 3956b56d1ec6SPeter Hawkins .def(nb::init<PyValue &, bool>(), nb::arg("value"), 3957b56d1ec6SPeter Hawkins nb::arg("use_local_scope") = false) 3958b56d1ec6SPeter Hawkins .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"), 3959b56d1ec6SPeter Hawkins nb::arg("use_local_scope") = false); 396075453714SJacques Pienaar 396130d61893SAlex Zinenko //---------------------------------------------------------------------------- 396230d61893SAlex Zinenko // Mapping of SymbolTable. 396330d61893SAlex Zinenko //---------------------------------------------------------------------------- 3964b56d1ec6SPeter Hawkins nb::class_<PySymbolTable>(m, "SymbolTable") 3965b56d1ec6SPeter Hawkins .def(nb::init<PyOperationBase &>()) 396630d61893SAlex Zinenko .def("__getitem__", &PySymbolTable::dunderGetItem) 3967b56d1ec6SPeter Hawkins .def("insert", &PySymbolTable::insert, nb::arg("operation")) 3968b56d1ec6SPeter Hawkins .def("erase", &PySymbolTable::erase, nb::arg("operation")) 396930d61893SAlex Zinenko .def("__delitem__", &PySymbolTable::dunderDel) 3970bdc31837SStella Laurenzo .def("__contains__", 3971bdc31837SStella Laurenzo [](PySymbolTable &table, const std::string &name) { 397230d61893SAlex Zinenko return !mlirOperationIsNull(mlirSymbolTableLookup( 397330d61893SAlex Zinenko table, mlirStringRefCreate(name.data(), name.length()))); 3974bdc31837SStella Laurenzo }) 3975bdc31837SStella Laurenzo // Static helpers. 3976bdc31837SStella Laurenzo .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3977b56d1ec6SPeter Hawkins nb::arg("symbol"), nb::arg("name")) 3978bdc31837SStella Laurenzo .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3979b56d1ec6SPeter Hawkins nb::arg("symbol")) 3980bdc31837SStella Laurenzo .def_static("get_visibility", &PySymbolTable::getVisibility, 3981b56d1ec6SPeter Hawkins nb::arg("symbol")) 3982bdc31837SStella Laurenzo .def_static("set_visibility", &PySymbolTable::setVisibility, 3983b56d1ec6SPeter Hawkins nb::arg("symbol"), nb::arg("visibility")) 3984bdc31837SStella Laurenzo .def_static("replace_all_symbol_uses", 3985b56d1ec6SPeter Hawkins &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), 3986b56d1ec6SPeter Hawkins nb::arg("new_symbol"), nb::arg("from_op")) 3987bdc31837SStella Laurenzo .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3988b56d1ec6SPeter Hawkins nb::arg("from_op"), nb::arg("all_sym_uses_visible"), 3989b56d1ec6SPeter Hawkins nb::arg("callback")); 399030d61893SAlex Zinenko 3991436c6c9cSStella Laurenzo // Container bindings. 3992436c6c9cSStella Laurenzo PyBlockArgumentList::bind(m); 3993436c6c9cSStella Laurenzo PyBlockIterator::bind(m); 3994436c6c9cSStella Laurenzo PyBlockList::bind(m); 3995436c6c9cSStella Laurenzo PyOperationIterator::bind(m); 3996436c6c9cSStella Laurenzo PyOperationList::bind(m); 3997436c6c9cSStella Laurenzo PyOpAttributeMap::bind(m); 3998afb2ed80SMike Urbach PyOpOperandIterator::bind(m); 3999436c6c9cSStella Laurenzo PyOpOperandList::bind(m); 4000436c6c9cSStella Laurenzo PyOpResultList::bind(m); 4001d7e49736SMaksim Levental PyOpSuccessors::bind(m); 4002436c6c9cSStella Laurenzo PyRegionIterator::bind(m); 4003436c6c9cSStella Laurenzo PyRegionList::bind(m); 40044acd8457SAlex Zinenko 40054acd8457SAlex Zinenko // Debug bindings. 40064acd8457SAlex Zinenko PyGlobalDebugFlag::bind(m); 4007b57acb9aSJacques Pienaar 4008b57acb9aSJacques Pienaar // Attribute builder getter. 4009b57acb9aSJacques Pienaar PyAttrBuilderMap::bind(m); 40103ea4c501SRahul Kayaith 4011b56d1ec6SPeter Hawkins nb::register_exception_translator([](const std::exception_ptr &p, 4012b56d1ec6SPeter Hawkins void *payload) { 40133ea4c501SRahul Kayaith // We can't define exceptions with custom fields through pybind, so instead 40143ea4c501SRahul Kayaith // the exception class is defined in python and imported here. 40153ea4c501SRahul Kayaith try { 40163ea4c501SRahul Kayaith if (p) 40173ea4c501SRahul Kayaith std::rethrow_exception(p); 40183ea4c501SRahul Kayaith } catch (const MLIRError &e) { 4019b56d1ec6SPeter Hawkins nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 40203ea4c501SRahul Kayaith .attr("MLIRError")(e.message, e.errorDiagnostics); 40213ea4c501SRahul Kayaith PyErr_SetObject(PyExc_Exception, obj.ptr()); 40223ea4c501SRahul Kayaith } 40233ea4c501SRahul Kayaith }); 4024436c6c9cSStella Laurenzo } 4025