1392622d0SMaksim Levental //===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// 2f13893f6SStella Laurenzo // 3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f13893f6SStella Laurenzo // 7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===// 8f13893f6SStella Laurenzo // This file contains adaptors for clients of the core MLIR Python APIs to 9392622d0SMaksim Levental // interop via MLIR CAPI types, using pybind11. The facilities here do not 10392622d0SMaksim Levental // depend on implementation details of the MLIR Python API and do not introduce 11392622d0SMaksim Levental // C++-level dependencies with it (requiring only Python and CAPI-level 12392622d0SMaksim Levental // dependencies). 13f13893f6SStella Laurenzo // 14f13893f6SStella Laurenzo // It is encouraged to be used both in-tree and out-of-tree. For in-tree use 15f13893f6SStella Laurenzo // cases, it should be used for dialect implementations (versus relying on 16f13893f6SStella Laurenzo // Pybind-based internals of the core libraries). 17f13893f6SStella Laurenzo //===----------------------------------------------------------------------===// 18f13893f6SStella Laurenzo 198f23296bSMehdi Amini #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 208f23296bSMehdi Amini #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 21f13893f6SStella Laurenzo 2247148832SHideto Ueno #include <pybind11/functional.h> 23f13893f6SStella Laurenzo #include <pybind11/pybind11.h> 24f13893f6SStella Laurenzo #include <pybind11/pytypes.h> 25f13893f6SStella Laurenzo #include <pybind11/stl.h> 26f13893f6SStella Laurenzo 27f13893f6SStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h" 2891f11611SOleksandr "Alex" Zinenko #include "mlir-c/Diagnostics.h" 29f13893f6SStella Laurenzo #include "mlir-c/IR.h" 30f13893f6SStella Laurenzo 31f13893f6SStella Laurenzo #include "llvm/ADT/Twine.h" 32f13893f6SStella Laurenzo 33f13893f6SStella Laurenzo namespace py = pybind11; 34bfb1ba75Smax using namespace py::literals; 35f13893f6SStella Laurenzo 36f13893f6SStella Laurenzo // Raw CAPI type casters need to be declared before use, so always include them 37f13893f6SStella Laurenzo // first. 38f13893f6SStella Laurenzo namespace pybind11 { 39f13893f6SStella Laurenzo namespace detail { 40f13893f6SStella Laurenzo 41f13893f6SStella Laurenzo /// Helper to convert a presumed MLIR API object to a capsule, accepting either 42f13893f6SStella Laurenzo /// an explicit Capsule (which can happen when two C APIs are communicating 43f13893f6SStella Laurenzo /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR 44f13893f6SStella Laurenzo /// attribute (through which supported MLIR Python API objects export their 4522fea18eSAlex Zinenko /// contained API pointer as a capsule). Throws a type error if the object is 4622fea18eSAlex Zinenko /// neither. This is intended to be used from type casters, which are invoked 4722fea18eSAlex Zinenko /// with a raw handle (unowned). The returned object's lifetime may not extend 4822fea18eSAlex Zinenko /// beyond the apiObject handle without explicitly having its refcount increased 4922fea18eSAlex Zinenko /// (i.e. on return). 50f13893f6SStella Laurenzo static py::object mlirApiObjectToCapsule(py::handle apiObject) { 51f13893f6SStella Laurenzo if (PyCapsule_CheckExact(apiObject.ptr())) 52f13893f6SStella Laurenzo return py::reinterpret_borrow<py::object>(apiObject); 5322fea18eSAlex Zinenko if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { 5422fea18eSAlex Zinenko auto repr = py::repr(apiObject).cast<std::string>(); 5522fea18eSAlex Zinenko throw py::type_error( 5622fea18eSAlex Zinenko (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str()); 5722fea18eSAlex Zinenko } 58f13893f6SStella Laurenzo return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); 59f13893f6SStella Laurenzo } 60f13893f6SStella Laurenzo 61f13893f6SStella Laurenzo // Note: Currently all of the following support cast from py::object to the 62f13893f6SStella Laurenzo // Mlir* C-API type, but only a few light-weight, context-bound ones 63f13893f6SStella Laurenzo // implicitly cast the other way because the use case has not yet emerged and 64f13893f6SStella Laurenzo // ownership is unclear. 65f13893f6SStella Laurenzo 66f13893f6SStella Laurenzo /// Casts object <-> MlirAffineMap. 67f13893f6SStella Laurenzo template <> 68f13893f6SStella Laurenzo struct type_caster<MlirAffineMap> { 69f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap")); 70f13893f6SStella Laurenzo bool load(handle src, bool) { 71f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 72f13893f6SStella Laurenzo value = mlirPythonCapsuleToAffineMap(capsule.ptr()); 73f13893f6SStella Laurenzo if (mlirAffineMapIsNull(value)) { 74f13893f6SStella Laurenzo return false; 75f13893f6SStella Laurenzo } 76f13893f6SStella Laurenzo return !mlirAffineMapIsNull(value); 77f13893f6SStella Laurenzo } 78f13893f6SStella Laurenzo static handle cast(MlirAffineMap v, return_value_policy, handle) { 79f13893f6SStella Laurenzo py::object capsule = 80f13893f6SStella Laurenzo py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v)); 81e78b745cSStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 82f13893f6SStella Laurenzo .attr("AffineMap") 83f13893f6SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 84f13893f6SStella Laurenzo .release(); 85f13893f6SStella Laurenzo } 86f13893f6SStella Laurenzo }; 87f13893f6SStella Laurenzo 88f13893f6SStella Laurenzo /// Casts object <-> MlirAttribute. 89f13893f6SStella Laurenzo template <> 90f13893f6SStella Laurenzo struct type_caster<MlirAttribute> { 91f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute")); 92f13893f6SStella Laurenzo bool load(handle src, bool) { 93f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 94f13893f6SStella Laurenzo value = mlirPythonCapsuleToAttribute(capsule.ptr()); 9558ec17cbSMehdi Amini return !mlirAttributeIsNull(value); 96f13893f6SStella Laurenzo } 97f13893f6SStella Laurenzo static handle cast(MlirAttribute v, return_value_policy, handle) { 98f13893f6SStella Laurenzo py::object capsule = 99f13893f6SStella Laurenzo py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v)); 100e78b745cSStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 101f13893f6SStella Laurenzo .attr("Attribute") 102f13893f6SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 1039566ee28Smax .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() 104f13893f6SStella Laurenzo .release(); 105f13893f6SStella Laurenzo } 106f13893f6SStella Laurenzo }; 107f13893f6SStella Laurenzo 108c83318e3SAdam Paszke /// Casts object -> MlirBlock. 109c83318e3SAdam Paszke template <> 110c83318e3SAdam Paszke struct type_caster<MlirBlock> { 111c83318e3SAdam Paszke PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock")); 112c83318e3SAdam Paszke bool load(handle src, bool) { 113c83318e3SAdam Paszke py::object capsule = mlirApiObjectToCapsule(src); 114c83318e3SAdam Paszke value = mlirPythonCapsuleToBlock(capsule.ptr()); 115c83318e3SAdam Paszke return !mlirBlockIsNull(value); 116c83318e3SAdam Paszke } 117c83318e3SAdam Paszke }; 118c83318e3SAdam Paszke 119f13893f6SStella Laurenzo /// Casts object -> MlirContext. 120f13893f6SStella Laurenzo template <> 121f13893f6SStella Laurenzo struct type_caster<MlirContext> { 122f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); 123f13893f6SStella Laurenzo bool load(handle src, bool) { 124f13893f6SStella Laurenzo if (src.is_none()) { 125f13893f6SStella Laurenzo // Gets the current thread-bound context. 126f13893f6SStella Laurenzo // TODO: This raises an error of "No current context" currently. 127f13893f6SStella Laurenzo // Update the implementation to pretty-print the helpful error that the 128f13893f6SStella Laurenzo // core implementations print in this case. 129e78b745cSStella Laurenzo src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 130f13893f6SStella Laurenzo .attr("Context") 131f13893f6SStella Laurenzo .attr("current"); 132f13893f6SStella Laurenzo } 133f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 134f13893f6SStella Laurenzo value = mlirPythonCapsuleToContext(capsule.ptr()); 13558ec17cbSMehdi Amini return !mlirContextIsNull(value); 136f13893f6SStella Laurenzo } 137f13893f6SStella Laurenzo }; 138f13893f6SStella Laurenzo 1395e83a5b4SStella Laurenzo /// Casts object <-> MlirDialectRegistry. 1405e83a5b4SStella Laurenzo template <> 1415e83a5b4SStella Laurenzo struct type_caster<MlirDialectRegistry> { 1425e83a5b4SStella Laurenzo PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); 1435e83a5b4SStella Laurenzo bool load(handle src, bool) { 1445e83a5b4SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 1455e83a5b4SStella Laurenzo value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 1465e83a5b4SStella Laurenzo return !mlirDialectRegistryIsNull(value); 1475e83a5b4SStella Laurenzo } 1485e83a5b4SStella Laurenzo static handle cast(MlirDialectRegistry v, return_value_policy, handle) { 1495e83a5b4SStella Laurenzo py::object capsule = py::reinterpret_steal<py::object>( 1505e83a5b4SStella Laurenzo mlirPythonDialectRegistryToCapsule(v)); 1515e83a5b4SStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 1525e83a5b4SStella Laurenzo .attr("DialectRegistry") 1535e83a5b4SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 1545e83a5b4SStella Laurenzo .release(); 1555e83a5b4SStella Laurenzo } 1565e83a5b4SStella Laurenzo }; 1575e83a5b4SStella Laurenzo 158f13893f6SStella Laurenzo /// Casts object <-> MlirLocation. 159f13893f6SStella Laurenzo template <> 160f13893f6SStella Laurenzo struct type_caster<MlirLocation> { 161f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); 162f13893f6SStella Laurenzo bool load(handle src, bool) { 16394d96c2aSJohn Demme if (src.is_none()) { 16494d96c2aSJohn Demme // Gets the current thread-bound context. 16594d96c2aSJohn Demme src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 16694d96c2aSJohn Demme .attr("Location") 16794d96c2aSJohn Demme .attr("current"); 16894d96c2aSJohn Demme } 169f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 170f13893f6SStella Laurenzo value = mlirPythonCapsuleToLocation(capsule.ptr()); 17158ec17cbSMehdi Amini return !mlirLocationIsNull(value); 172f13893f6SStella Laurenzo } 173f13893f6SStella Laurenzo static handle cast(MlirLocation v, return_value_policy, handle) { 174f13893f6SStella Laurenzo py::object capsule = 175f13893f6SStella Laurenzo py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v)); 176e78b745cSStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 177f13893f6SStella Laurenzo .attr("Location") 178f13893f6SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 179f13893f6SStella Laurenzo .release(); 180f13893f6SStella Laurenzo } 181f13893f6SStella Laurenzo }; 182f13893f6SStella Laurenzo 183f13893f6SStella Laurenzo /// Casts object <-> MlirModule. 184f13893f6SStella Laurenzo template <> 185f13893f6SStella Laurenzo struct type_caster<MlirModule> { 186f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); 187f13893f6SStella Laurenzo bool load(handle src, bool) { 188f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 189f13893f6SStella Laurenzo value = mlirPythonCapsuleToModule(capsule.ptr()); 19058ec17cbSMehdi Amini return !mlirModuleIsNull(value); 191f13893f6SStella Laurenzo } 192f13893f6SStella Laurenzo static handle cast(MlirModule v, return_value_policy, handle) { 193f13893f6SStella Laurenzo py::object capsule = 194f13893f6SStella Laurenzo py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v)); 195e78b745cSStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 196f13893f6SStella Laurenzo .attr("Module") 197f13893f6SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 198f13893f6SStella Laurenzo .release(); 199f13893f6SStella Laurenzo }; 200f13893f6SStella Laurenzo }; 201f13893f6SStella Laurenzo 20218cf1cd9SJacques Pienaar /// Casts object <-> MlirFrozenRewritePatternSet. 20318cf1cd9SJacques Pienaar template <> 20418cf1cd9SJacques Pienaar struct type_caster<MlirFrozenRewritePatternSet> { 20518cf1cd9SJacques Pienaar PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, 20618cf1cd9SJacques Pienaar _("MlirFrozenRewritePatternSet")); 20718cf1cd9SJacques Pienaar bool load(handle src, bool) { 20818cf1cd9SJacques Pienaar py::object capsule = mlirApiObjectToCapsule(src); 20918cf1cd9SJacques Pienaar value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); 21018cf1cd9SJacques Pienaar return value.ptr != nullptr; 21118cf1cd9SJacques Pienaar } 21218cf1cd9SJacques Pienaar static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, 21318cf1cd9SJacques Pienaar handle) { 21418cf1cd9SJacques Pienaar py::object capsule = py::reinterpret_steal<py::object>( 21518cf1cd9SJacques Pienaar mlirPythonFrozenRewritePatternSetToCapsule(v)); 21618cf1cd9SJacques Pienaar return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) 21718cf1cd9SJacques Pienaar .attr("FrozenRewritePatternSet") 21818cf1cd9SJacques Pienaar .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 21918cf1cd9SJacques Pienaar .release(); 22018cf1cd9SJacques Pienaar }; 22118cf1cd9SJacques Pienaar }; 22218cf1cd9SJacques Pienaar 223f13893f6SStella Laurenzo /// Casts object <-> MlirOperation. 224f13893f6SStella Laurenzo template <> 225f13893f6SStella Laurenzo struct type_caster<MlirOperation> { 226f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation")); 227f13893f6SStella Laurenzo bool load(handle src, bool) { 228f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 229f13893f6SStella Laurenzo value = mlirPythonCapsuleToOperation(capsule.ptr()); 23058ec17cbSMehdi Amini return !mlirOperationIsNull(value); 231f13893f6SStella Laurenzo } 232f13893f6SStella Laurenzo static handle cast(MlirOperation v, return_value_policy, handle) { 233f13893f6SStella Laurenzo if (v.ptr == nullptr) 234f13893f6SStella Laurenzo return py::none(); 235f13893f6SStella Laurenzo py::object capsule = 236f13893f6SStella Laurenzo py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v)); 237e78b745cSStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 238f13893f6SStella Laurenzo .attr("Operation") 239f13893f6SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 240f13893f6SStella Laurenzo .release(); 241f13893f6SStella Laurenzo }; 242f13893f6SStella Laurenzo }; 243f13893f6SStella Laurenzo 2445e9b6a22SJohn Demme /// Casts object <-> MlirValue. 2455e9b6a22SJohn Demme template <> 2465e9b6a22SJohn Demme struct type_caster<MlirValue> { 2475e9b6a22SJohn Demme PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue")); 2485e9b6a22SJohn Demme bool load(handle src, bool) { 2495e9b6a22SJohn Demme py::object capsule = mlirApiObjectToCapsule(src); 2505e9b6a22SJohn Demme value = mlirPythonCapsuleToValue(capsule.ptr()); 2515e9b6a22SJohn Demme return !mlirValueIsNull(value); 2525e9b6a22SJohn Demme } 2535e9b6a22SJohn Demme static handle cast(MlirValue v, return_value_policy, handle) { 2545e9b6a22SJohn Demme if (v.ptr == nullptr) 2555e9b6a22SJohn Demme return py::none(); 2565e9b6a22SJohn Demme py::object capsule = 2575e9b6a22SJohn Demme py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(v)); 2585e9b6a22SJohn Demme return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 2595e9b6a22SJohn Demme .attr("Value") 2605e9b6a22SJohn Demme .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 2617c850867SMaksim Levental .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() 2625e9b6a22SJohn Demme .release(); 2635e9b6a22SJohn Demme }; 2645e9b6a22SJohn Demme }; 2655e9b6a22SJohn Demme 266f13893f6SStella Laurenzo /// Casts object -> MlirPassManager. 267f13893f6SStella Laurenzo template <> 268f13893f6SStella Laurenzo struct type_caster<MlirPassManager> { 269f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); 270f13893f6SStella Laurenzo bool load(handle src, bool) { 271f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 272f13893f6SStella Laurenzo value = mlirPythonCapsuleToPassManager(capsule.ptr()); 27358ec17cbSMehdi Amini return !mlirPassManagerIsNull(value); 274f13893f6SStella Laurenzo } 275f13893f6SStella Laurenzo }; 276f13893f6SStella Laurenzo 277d39a7844Smax /// Casts object <-> MlirTypeID. 278d39a7844Smax template <> 279d39a7844Smax struct type_caster<MlirTypeID> { 280d39a7844Smax PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID")); 281d39a7844Smax bool load(handle src, bool) { 282d39a7844Smax py::object capsule = mlirApiObjectToCapsule(src); 283d39a7844Smax value = mlirPythonCapsuleToTypeID(capsule.ptr()); 284d39a7844Smax return !mlirTypeIDIsNull(value); 285d39a7844Smax } 286d39a7844Smax static handle cast(MlirTypeID v, return_value_policy, handle) { 287d39a7844Smax if (v.ptr == nullptr) 288d39a7844Smax return py::none(); 289d39a7844Smax py::object capsule = 290d39a7844Smax py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v)); 291d39a7844Smax return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 292d39a7844Smax .attr("TypeID") 293d39a7844Smax .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 294d39a7844Smax .release(); 295d39a7844Smax }; 296d39a7844Smax }; 297d39a7844Smax 298f13893f6SStella Laurenzo /// Casts object <-> MlirType. 299f13893f6SStella Laurenzo template <> 300f13893f6SStella Laurenzo struct type_caster<MlirType> { 301f13893f6SStella Laurenzo PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); 302f13893f6SStella Laurenzo bool load(handle src, bool) { 303f13893f6SStella Laurenzo py::object capsule = mlirApiObjectToCapsule(src); 304f13893f6SStella Laurenzo value = mlirPythonCapsuleToType(capsule.ptr()); 30558ec17cbSMehdi Amini return !mlirTypeIsNull(value); 306f13893f6SStella Laurenzo } 307f13893f6SStella Laurenzo static handle cast(MlirType t, return_value_policy, handle) { 308f13893f6SStella Laurenzo py::object capsule = 309f13893f6SStella Laurenzo py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t)); 310e78b745cSStella Laurenzo return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 311f13893f6SStella Laurenzo .attr("Type") 312f13893f6SStella Laurenzo .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 313bfb1ba75Smax .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() 314f13893f6SStella Laurenzo .release(); 315f13893f6SStella Laurenzo } 316f13893f6SStella Laurenzo }; 317f13893f6SStella Laurenzo 318f13893f6SStella Laurenzo } // namespace detail 319f13893f6SStella Laurenzo } // namespace pybind11 320f13893f6SStella Laurenzo 321f13893f6SStella Laurenzo namespace mlir { 322f13893f6SStella Laurenzo namespace python { 323f13893f6SStella Laurenzo namespace adaptors { 324f13893f6SStella Laurenzo 325f13893f6SStella Laurenzo /// Provides a facility like py::class_ for defining a new class in a scope, 326f13893f6SStella Laurenzo /// but this allows extension of an arbitrary Python class, defining methods 327f13893f6SStella Laurenzo /// on it is a similar way. Classes defined in this way are very similar to 328f13893f6SStella Laurenzo /// if defined in Python in the usual way but use Pybind11 machinery to do 329f13893f6SStella Laurenzo /// it. These are not "real" Pybind11 classes but pure Python classes with no 330f13893f6SStella Laurenzo /// relation to a concrete C++ class. 331f13893f6SStella Laurenzo /// 332f13893f6SStella Laurenzo /// Derived from a discussion upstream: 333f13893f6SStella Laurenzo /// https://github.com/pybind/pybind11/issues/1193 334f13893f6SStella Laurenzo /// (plus a fair amount of extra curricular poking) 335f13893f6SStella Laurenzo /// TODO: If this proves useful, see about including it in pybind11. 336f13893f6SStella Laurenzo class pure_subclass { 337f13893f6SStella Laurenzo public: 338f13893f6SStella Laurenzo pure_subclass(py::handle scope, const char *derivedClassName, 339e8d07395SMehdi Amini const py::object &superClass) { 340f13893f6SStella Laurenzo py::object pyType = 341f13893f6SStella Laurenzo py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 342f13893f6SStella Laurenzo py::object metaclass = pyType(superClass); 343f13893f6SStella Laurenzo py::dict attributes; 344f13893f6SStella Laurenzo 345f13893f6SStella Laurenzo thisClass = 346f13893f6SStella Laurenzo metaclass(derivedClassName, py::make_tuple(superClass), attributes); 347f13893f6SStella Laurenzo scope.attr(derivedClassName) = thisClass; 348f13893f6SStella Laurenzo } 349f13893f6SStella Laurenzo 350f13893f6SStella Laurenzo template <typename Func, typename... Extra> 351f13893f6SStella Laurenzo pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { 352f13893f6SStella Laurenzo py::cpp_function cf( 3538b734798SAlex Zinenko std::forward<Func>(f), py::name(name), py::is_method(thisClass), 354f13893f6SStella Laurenzo py::sibling(py::getattr(thisClass, name, py::none())), extra...); 355f13893f6SStella Laurenzo thisClass.attr(cf.name()) = cf; 356f13893f6SStella Laurenzo return *this; 357f13893f6SStella Laurenzo } 358f13893f6SStella Laurenzo 359f13893f6SStella Laurenzo template <typename Func, typename... Extra> 360f13893f6SStella Laurenzo pure_subclass &def_property_readonly(const char *name, Func &&f, 361f13893f6SStella Laurenzo const Extra &...extra) { 362f13893f6SStella Laurenzo py::cpp_function cf( 3638b734798SAlex Zinenko std::forward<Func>(f), py::name(name), py::is_method(thisClass), 364f13893f6SStella Laurenzo py::sibling(py::getattr(thisClass, name, py::none())), extra...); 365f13893f6SStella Laurenzo auto builtinProperty = 366f13893f6SStella Laurenzo py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type); 367f13893f6SStella Laurenzo thisClass.attr(name) = builtinProperty(cf); 368f13893f6SStella Laurenzo return *this; 369f13893f6SStella Laurenzo } 370f13893f6SStella Laurenzo 371f13893f6SStella Laurenzo template <typename Func, typename... Extra> 372f13893f6SStella Laurenzo pure_subclass &def_staticmethod(const char *name, Func &&f, 373f13893f6SStella Laurenzo const Extra &...extra) { 374f13893f6SStella Laurenzo static_assert(!std::is_member_function_pointer<Func>::value, 375f13893f6SStella Laurenzo "def_staticmethod(...) called with a non-static member " 376f13893f6SStella Laurenzo "function pointer"); 377*b56d1ec6SPeter Hawkins py::cpp_function cf(std::forward<Func>(f), py::name(name), 378*b56d1ec6SPeter Hawkins py::scope(thisClass), extra...); 379f13893f6SStella Laurenzo thisClass.attr(cf.name()) = py::staticmethod(cf); 380f13893f6SStella Laurenzo return *this; 381f13893f6SStella Laurenzo } 382f13893f6SStella Laurenzo 383f13893f6SStella Laurenzo template <typename Func, typename... Extra> 384f13893f6SStella Laurenzo pure_subclass &def_classmethod(const char *name, Func &&f, 385f13893f6SStella Laurenzo const Extra &...extra) { 386f13893f6SStella Laurenzo static_assert(!std::is_member_function_pointer<Func>::value, 387f13893f6SStella Laurenzo "def_classmethod(...) called with a non-static member " 388f13893f6SStella Laurenzo "function pointer"); 389*b56d1ec6SPeter Hawkins py::cpp_function cf(std::forward<Func>(f), py::name(name), 390*b56d1ec6SPeter Hawkins py::scope(thisClass), extra...); 391f13893f6SStella Laurenzo thisClass.attr(cf.name()) = 392f13893f6SStella Laurenzo py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr())); 393f13893f6SStella Laurenzo return *this; 394f13893f6SStella Laurenzo } 395f13893f6SStella Laurenzo 39666d4090dSAlex Zinenko py::object get_class() const { return thisClass; } 39766d4090dSAlex Zinenko 398f13893f6SStella Laurenzo protected: 399f13893f6SStella Laurenzo py::object superClass; 400f13893f6SStella Laurenzo py::object thisClass; 401f13893f6SStella Laurenzo }; 402f13893f6SStella Laurenzo 403f13893f6SStella Laurenzo /// Creates a custom subclass of mlir.ir.Attribute, implementing a casting 404f13893f6SStella Laurenzo /// constructor and type checking methods. 405f13893f6SStella Laurenzo class mlir_attribute_subclass : public pure_subclass { 406f13893f6SStella Laurenzo public: 407f13893f6SStella Laurenzo using IsAFunctionTy = bool (*)(MlirAttribute); 40893156458SMaksim Levental using GetTypeIDFunctionTy = MlirTypeID (*)(); 409f13893f6SStella Laurenzo 410f13893f6SStella Laurenzo /// Subclasses by looking up the super-class dynamically. 411f13893f6SStella Laurenzo mlir_attribute_subclass(py::handle scope, const char *attrClassName, 41293156458SMaksim Levental IsAFunctionTy isaFunction, 41393156458SMaksim Levental GetTypeIDFunctionTy getTypeIDFunction = nullptr) 414f13893f6SStella Laurenzo : mlir_attribute_subclass( 415f13893f6SStella Laurenzo scope, attrClassName, isaFunction, 416e78b745cSStella Laurenzo py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 41793156458SMaksim Levental .attr("Attribute"), 41893156458SMaksim Levental getTypeIDFunction) {} 419f13893f6SStella Laurenzo 420f13893f6SStella Laurenzo /// Subclasses with a provided mlir.ir.Attribute super-class. This must 421f13893f6SStella Laurenzo /// be used if the subclass is being defined in the same extension module 422f13893f6SStella Laurenzo /// as the mlir.ir class (otherwise, it will trigger a recursive 423f13893f6SStella Laurenzo /// initialization). 424f13893f6SStella Laurenzo mlir_attribute_subclass(py::handle scope, const char *typeClassName, 42593156458SMaksim Levental IsAFunctionTy isaFunction, const py::object &superCls, 42693156458SMaksim Levental GetTypeIDFunctionTy getTypeIDFunction = nullptr) 42789a92fb3SAlex Zinenko : pure_subclass(scope, typeClassName, superCls) { 42889a92fb3SAlex Zinenko // Casting constructor. Note that it hard, if not impossible, to properly 42989a92fb3SAlex Zinenko // call chain to parent `__init__` in pybind11 due to its special handling 43089a92fb3SAlex Zinenko // for init functions that don't have a fully constructed self-reference, 43189a92fb3SAlex Zinenko // which makes it impossible to forward it to `__init__` of a superclass. 43289a92fb3SAlex Zinenko // Instead, provide a custom `__new__` and call that of a superclass, which 43389a92fb3SAlex Zinenko // eventually calls `__init__` of the superclass. Since attribute subclasses 43489a92fb3SAlex Zinenko // have no additional members, we can just return the instance thus created 43589a92fb3SAlex Zinenko // without amending it. 436f13893f6SStella Laurenzo std::string captureTypeName( 437f13893f6SStella Laurenzo typeClassName); // As string in case if typeClassName is not static. 43889a92fb3SAlex Zinenko py::cpp_function newCf( 43989a92fb3SAlex Zinenko [superCls, isaFunction, captureTypeName](py::object cls, 44089a92fb3SAlex Zinenko py::object otherAttribute) { 44189a92fb3SAlex Zinenko MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute); 442f13893f6SStella Laurenzo if (!isaFunction(rawAttribute)) { 44389a92fb3SAlex Zinenko auto origRepr = py::repr(otherAttribute).cast<std::string>(); 444f13893f6SStella Laurenzo throw std::invalid_argument( 445f13893f6SStella Laurenzo (llvm::Twine("Cannot cast attribute to ") + captureTypeName + 446f13893f6SStella Laurenzo " (from " + origRepr + ")") 447f13893f6SStella Laurenzo .str()); 448f13893f6SStella Laurenzo } 44989a92fb3SAlex Zinenko py::object self = superCls.attr("__new__")(cls, otherAttribute); 45089a92fb3SAlex Zinenko return self; 451f13893f6SStella Laurenzo }, 45289a92fb3SAlex Zinenko py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); 45389a92fb3SAlex Zinenko thisClass.attr("__new__") = newCf; 454f13893f6SStella Laurenzo 455f13893f6SStella Laurenzo // 'isinstance' method. 456f13893f6SStella Laurenzo def_staticmethod( 457f13893f6SStella Laurenzo "isinstance", 458f13893f6SStella Laurenzo [isaFunction](MlirAttribute other) { return isaFunction(other); }, 459f13893f6SStella Laurenzo py::arg("other_attribute")); 46093156458SMaksim Levental def("__repr__", [superCls, captureTypeName](py::object self) { 46193156458SMaksim Levental return py::repr(superCls(self)) 46293156458SMaksim Levental .attr("replace")(superCls.attr("__name__"), captureTypeName); 46393156458SMaksim Levental }); 46493156458SMaksim Levental if (getTypeIDFunction) { 46593156458SMaksim Levental def_staticmethod("get_static_typeid", 46693156458SMaksim Levental [getTypeIDFunction]() { return getTypeIDFunction(); }); 46793156458SMaksim Levental py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 46893156458SMaksim Levental .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 46993156458SMaksim Levental getTypeIDFunction())(pybind11::cpp_function( 47093156458SMaksim Levental [thisClass = thisClass](const py::object &mlirAttribute) { 47193156458SMaksim Levental return thisClass(mlirAttribute); 47293156458SMaksim Levental })); 47393156458SMaksim Levental } 474f13893f6SStella Laurenzo } 475f13893f6SStella Laurenzo }; 476f13893f6SStella Laurenzo 477f13893f6SStella Laurenzo /// Creates a custom subclass of mlir.ir.Type, implementing a casting 478f13893f6SStella Laurenzo /// constructor and type checking methods. 479f13893f6SStella Laurenzo class mlir_type_subclass : public pure_subclass { 480f13893f6SStella Laurenzo public: 481f13893f6SStella Laurenzo using IsAFunctionTy = bool (*)(MlirType); 482bfb1ba75Smax using GetTypeIDFunctionTy = MlirTypeID (*)(); 483f13893f6SStella Laurenzo 484f13893f6SStella Laurenzo /// Subclasses by looking up the super-class dynamically. 485f13893f6SStella Laurenzo mlir_type_subclass(py::handle scope, const char *typeClassName, 486bfb1ba75Smax IsAFunctionTy isaFunction, 487bfb1ba75Smax GetTypeIDFunctionTy getTypeIDFunction = nullptr) 488f13893f6SStella Laurenzo : mlir_type_subclass( 489f13893f6SStella Laurenzo scope, typeClassName, isaFunction, 490bfb1ba75Smax py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"), 491bfb1ba75Smax getTypeIDFunction) {} 492f13893f6SStella Laurenzo 493f13893f6SStella Laurenzo /// Subclasses with a provided mlir.ir.Type super-class. This must 494f13893f6SStella Laurenzo /// be used if the subclass is being defined in the same extension module 495f13893f6SStella Laurenzo /// as the mlir.ir class (otherwise, it will trigger a recursive 496f13893f6SStella Laurenzo /// initialization). 497f13893f6SStella Laurenzo mlir_type_subclass(py::handle scope, const char *typeClassName, 498bfb1ba75Smax IsAFunctionTy isaFunction, const py::object &superCls, 499bfb1ba75Smax GetTypeIDFunctionTy getTypeIDFunction = nullptr) 50089a92fb3SAlex Zinenko : pure_subclass(scope, typeClassName, superCls) { 50189a92fb3SAlex Zinenko // Casting constructor. Note that it hard, if not impossible, to properly 50289a92fb3SAlex Zinenko // call chain to parent `__init__` in pybind11 due to its special handling 50389a92fb3SAlex Zinenko // for init functions that don't have a fully constructed self-reference, 50489a92fb3SAlex Zinenko // which makes it impossible to forward it to `__init__` of a superclass. 50589a92fb3SAlex Zinenko // Instead, provide a custom `__new__` and call that of a superclass, which 50689a92fb3SAlex Zinenko // eventually calls `__init__` of the superclass. Since attribute subclasses 50789a92fb3SAlex Zinenko // have no additional members, we can just return the instance thus created 50889a92fb3SAlex Zinenko // without amending it. 509f13893f6SStella Laurenzo std::string captureTypeName( 510f13893f6SStella Laurenzo typeClassName); // As string in case if typeClassName is not static. 51189a92fb3SAlex Zinenko py::cpp_function newCf( 51289a92fb3SAlex Zinenko [superCls, isaFunction, captureTypeName](py::object cls, 5137b1ceee6SAlex Zinenko py::object otherType) { 514f13893f6SStella Laurenzo MlirType rawType = py::cast<MlirType>(otherType); 515f13893f6SStella Laurenzo if (!isaFunction(rawType)) { 516f13893f6SStella Laurenzo auto origRepr = py::repr(otherType).cast<std::string>(); 517f13893f6SStella Laurenzo throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + 518f13893f6SStella Laurenzo captureTypeName + " (from " + 519f13893f6SStella Laurenzo origRepr + ")") 520f13893f6SStella Laurenzo .str()); 521f13893f6SStella Laurenzo } 52289a92fb3SAlex Zinenko py::object self = superCls.attr("__new__")(cls, otherType); 52389a92fb3SAlex Zinenko return self; 524f13893f6SStella Laurenzo }, 52589a92fb3SAlex Zinenko py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); 52689a92fb3SAlex Zinenko thisClass.attr("__new__") = newCf; 527f13893f6SStella Laurenzo 528f13893f6SStella Laurenzo // 'isinstance' method. 529f13893f6SStella Laurenzo def_staticmethod( 530f13893f6SStella Laurenzo "isinstance", 531f13893f6SStella Laurenzo [isaFunction](MlirType other) { return isaFunction(other); }, 532f13893f6SStella Laurenzo py::arg("other_type")); 533bfb1ba75Smax def("__repr__", [superCls, captureTypeName](py::object self) { 534bfb1ba75Smax return py::repr(superCls(self)) 535bfb1ba75Smax .attr("replace")(superCls.attr("__name__"), captureTypeName); 536bfb1ba75Smax }); 537bfb1ba75Smax if (getTypeIDFunction) { 538681eacc1Smartin-luecke // 'get_static_typeid' method. 539681eacc1Smartin-luecke // This is modeled as a static method instead of a static property because 540681eacc1Smartin-luecke // `def_property_readonly_static` is not available in `pure_subclass` and 541681eacc1Smartin-luecke // we do not want to introduce the complexity that pybind uses to 542681eacc1Smartin-luecke // implement it. 543681eacc1Smartin-luecke def_staticmethod("get_static_typeid", 544681eacc1Smartin-luecke [getTypeIDFunction]() { return getTypeIDFunction(); }); 545bfb1ba75Smax py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 546bfb1ba75Smax .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 5477c850867SMaksim Levental getTypeIDFunction())(pybind11::cpp_function( 548bfb1ba75Smax [thisClass = thisClass](const py::object &mlirType) { 549bfb1ba75Smax return thisClass(mlirType); 550bfb1ba75Smax })); 551bfb1ba75Smax } 552f13893f6SStella Laurenzo } 553f13893f6SStella Laurenzo }; 554f13893f6SStella Laurenzo 55569cc3cfbSmax /// Creates a custom subclass of mlir.ir.Value, implementing a casting 55669cc3cfbSmax /// constructor and type checking methods. 55769cc3cfbSmax class mlir_value_subclass : public pure_subclass { 55869cc3cfbSmax public: 55969cc3cfbSmax using IsAFunctionTy = bool (*)(MlirValue); 56069cc3cfbSmax 56169cc3cfbSmax /// Subclasses by looking up the super-class dynamically. 56269cc3cfbSmax mlir_value_subclass(py::handle scope, const char *valueClassName, 56369cc3cfbSmax IsAFunctionTy isaFunction) 56469cc3cfbSmax : mlir_value_subclass( 56569cc3cfbSmax scope, valueClassName, isaFunction, 56669cc3cfbSmax py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) { 56769cc3cfbSmax } 56869cc3cfbSmax 56969cc3cfbSmax /// Subclasses with a provided mlir.ir.Value super-class. This must 57069cc3cfbSmax /// be used if the subclass is being defined in the same extension module 57169cc3cfbSmax /// as the mlir.ir class (otherwise, it will trigger a recursive 57269cc3cfbSmax /// initialization). 57369cc3cfbSmax mlir_value_subclass(py::handle scope, const char *valueClassName, 57469cc3cfbSmax IsAFunctionTy isaFunction, const py::object &superCls) 57569cc3cfbSmax : pure_subclass(scope, valueClassName, superCls) { 57669cc3cfbSmax // Casting constructor. Note that it hard, if not impossible, to properly 57769cc3cfbSmax // call chain to parent `__init__` in pybind11 due to its special handling 57869cc3cfbSmax // for init functions that don't have a fully constructed self-reference, 57969cc3cfbSmax // which makes it impossible to forward it to `__init__` of a superclass. 58069cc3cfbSmax // Instead, provide a custom `__new__` and call that of a superclass, which 58169cc3cfbSmax // eventually calls `__init__` of the superclass. Since attribute subclasses 58269cc3cfbSmax // have no additional members, we can just return the instance thus created 58369cc3cfbSmax // without amending it. 58469cc3cfbSmax std::string captureValueName( 58569cc3cfbSmax valueClassName); // As string in case if valueClassName is not static. 58669cc3cfbSmax py::cpp_function newCf( 58769cc3cfbSmax [superCls, isaFunction, captureValueName](py::object cls, 58869cc3cfbSmax py::object otherValue) { 58969cc3cfbSmax MlirValue rawValue = py::cast<MlirValue>(otherValue); 59069cc3cfbSmax if (!isaFunction(rawValue)) { 59169cc3cfbSmax auto origRepr = py::repr(otherValue).cast<std::string>(); 59269cc3cfbSmax throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + 59369cc3cfbSmax captureValueName + " (from " + 59469cc3cfbSmax origRepr + ")") 59569cc3cfbSmax .str()); 59669cc3cfbSmax } 59769cc3cfbSmax py::object self = superCls.attr("__new__")(cls, otherValue); 59869cc3cfbSmax return self; 59969cc3cfbSmax }, 60069cc3cfbSmax py::name("__new__"), py::arg("cls"), py::arg("cast_from_value")); 60169cc3cfbSmax thisClass.attr("__new__") = newCf; 60269cc3cfbSmax 60369cc3cfbSmax // 'isinstance' method. 60469cc3cfbSmax def_staticmethod( 60569cc3cfbSmax "isinstance", 60669cc3cfbSmax [isaFunction](MlirValue other) { return isaFunction(other); }, 60769cc3cfbSmax py::arg("other_value")); 60869cc3cfbSmax } 60969cc3cfbSmax }; 61069cc3cfbSmax 611f13893f6SStella Laurenzo } // namespace adaptors 61291f11611SOleksandr "Alex" Zinenko 613f13893f6SStella Laurenzo } // namespace python 614f13893f6SStella Laurenzo } // namespace mlir 615f13893f6SStella Laurenzo 6168f23296bSMehdi Amini #endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 617