xref: /llvm-project/mlir/include/mlir/Bindings/Python/PybindAdaptors.h (revision b56d1ec6cb8b5cb3ff46cba39a1049ecf3831afb)
1392622d0SMaksim Levental //===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===//
2f13893f6SStella Laurenzo //
3f13893f6SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f13893f6SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5f13893f6SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f13893f6SStella Laurenzo //
7f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
8f13893f6SStella Laurenzo // This file contains adaptors for clients of the core MLIR Python APIs to
9392622d0SMaksim Levental // interop via MLIR CAPI types, using pybind11. The facilities here do not
10392622d0SMaksim Levental // depend on implementation details of the MLIR Python API and do not introduce
11392622d0SMaksim Levental // C++-level dependencies with it (requiring only Python and CAPI-level
12392622d0SMaksim Levental // dependencies).
13f13893f6SStella Laurenzo //
14f13893f6SStella Laurenzo // It is encouraged to be used both in-tree and out-of-tree. For in-tree use
15f13893f6SStella Laurenzo // cases, it should be used for dialect implementations (versus relying on
16f13893f6SStella Laurenzo // Pybind-based internals of the core libraries).
17f13893f6SStella Laurenzo //===----------------------------------------------------------------------===//
18f13893f6SStella Laurenzo 
198f23296bSMehdi Amini #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
208f23296bSMehdi Amini #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
21f13893f6SStella Laurenzo 
2247148832SHideto Ueno #include <pybind11/functional.h>
23f13893f6SStella Laurenzo #include <pybind11/pybind11.h>
24f13893f6SStella Laurenzo #include <pybind11/pytypes.h>
25f13893f6SStella Laurenzo #include <pybind11/stl.h>
26f13893f6SStella Laurenzo 
27f13893f6SStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
2891f11611SOleksandr "Alex" Zinenko #include "mlir-c/Diagnostics.h"
29f13893f6SStella Laurenzo #include "mlir-c/IR.h"
30f13893f6SStella Laurenzo 
31f13893f6SStella Laurenzo #include "llvm/ADT/Twine.h"
32f13893f6SStella Laurenzo 
33f13893f6SStella Laurenzo namespace py = pybind11;
34bfb1ba75Smax using namespace py::literals;
35f13893f6SStella Laurenzo 
36f13893f6SStella Laurenzo // Raw CAPI type casters need to be declared before use, so always include them
37f13893f6SStella Laurenzo // first.
38f13893f6SStella Laurenzo namespace pybind11 {
39f13893f6SStella Laurenzo namespace detail {
40f13893f6SStella Laurenzo 
41f13893f6SStella Laurenzo /// Helper to convert a presumed MLIR API object to a capsule, accepting either
42f13893f6SStella Laurenzo /// an explicit Capsule (which can happen when two C APIs are communicating
43f13893f6SStella Laurenzo /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
44f13893f6SStella Laurenzo /// attribute (through which supported MLIR Python API objects export their
4522fea18eSAlex Zinenko /// contained API pointer as a capsule). Throws a type error if the object is
4622fea18eSAlex Zinenko /// neither. This is intended to be used from type casters, which are invoked
4722fea18eSAlex Zinenko /// with a raw handle (unowned). The returned object's lifetime may not extend
4822fea18eSAlex Zinenko /// beyond the apiObject handle without explicitly having its refcount increased
4922fea18eSAlex Zinenko /// (i.e. on return).
50f13893f6SStella Laurenzo static py::object mlirApiObjectToCapsule(py::handle apiObject) {
51f13893f6SStella Laurenzo   if (PyCapsule_CheckExact(apiObject.ptr()))
52f13893f6SStella Laurenzo     return py::reinterpret_borrow<py::object>(apiObject);
5322fea18eSAlex Zinenko   if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) {
5422fea18eSAlex Zinenko     auto repr = py::repr(apiObject).cast<std::string>();
5522fea18eSAlex Zinenko     throw py::type_error(
5622fea18eSAlex Zinenko         (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str());
5722fea18eSAlex Zinenko   }
58f13893f6SStella Laurenzo   return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
59f13893f6SStella Laurenzo }
60f13893f6SStella Laurenzo 
// Note: All of the following support casting from py::object to the Mlir*
// C-API type. Most also cast implicitly the other way; the exceptions
// (MlirBlock, MlirContext and MlirPassManager) are one-way only because the
// use case has not yet emerged and ownership is unclear.
65f13893f6SStella Laurenzo 
66f13893f6SStella Laurenzo /// Casts object <-> MlirAffineMap.
67f13893f6SStella Laurenzo template <>
68f13893f6SStella Laurenzo struct type_caster<MlirAffineMap> {
69f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"));
70f13893f6SStella Laurenzo   bool load(handle src, bool) {
71f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
72f13893f6SStella Laurenzo     value = mlirPythonCapsuleToAffineMap(capsule.ptr());
73f13893f6SStella Laurenzo     if (mlirAffineMapIsNull(value)) {
74f13893f6SStella Laurenzo       return false;
75f13893f6SStella Laurenzo     }
76f13893f6SStella Laurenzo     return !mlirAffineMapIsNull(value);
77f13893f6SStella Laurenzo   }
78f13893f6SStella Laurenzo   static handle cast(MlirAffineMap v, return_value_policy, handle) {
79f13893f6SStella Laurenzo     py::object capsule =
80f13893f6SStella Laurenzo         py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v));
81e78b745cSStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
82f13893f6SStella Laurenzo         .attr("AffineMap")
83f13893f6SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
84f13893f6SStella Laurenzo         .release();
85f13893f6SStella Laurenzo   }
86f13893f6SStella Laurenzo };
87f13893f6SStella Laurenzo 
88f13893f6SStella Laurenzo /// Casts object <-> MlirAttribute.
89f13893f6SStella Laurenzo template <>
90f13893f6SStella Laurenzo struct type_caster<MlirAttribute> {
91f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
92f13893f6SStella Laurenzo   bool load(handle src, bool) {
93f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
94f13893f6SStella Laurenzo     value = mlirPythonCapsuleToAttribute(capsule.ptr());
9558ec17cbSMehdi Amini     return !mlirAttributeIsNull(value);
96f13893f6SStella Laurenzo   }
97f13893f6SStella Laurenzo   static handle cast(MlirAttribute v, return_value_policy, handle) {
98f13893f6SStella Laurenzo     py::object capsule =
99f13893f6SStella Laurenzo         py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
100e78b745cSStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
101f13893f6SStella Laurenzo         .attr("Attribute")
102f13893f6SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
1039566ee28Smax         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
104f13893f6SStella Laurenzo         .release();
105f13893f6SStella Laurenzo   }
106f13893f6SStella Laurenzo };
107f13893f6SStella Laurenzo 
108c83318e3SAdam Paszke /// Casts object -> MlirBlock.
109c83318e3SAdam Paszke template <>
110c83318e3SAdam Paszke struct type_caster<MlirBlock> {
111c83318e3SAdam Paszke   PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"));
112c83318e3SAdam Paszke   bool load(handle src, bool) {
113c83318e3SAdam Paszke     py::object capsule = mlirApiObjectToCapsule(src);
114c83318e3SAdam Paszke     value = mlirPythonCapsuleToBlock(capsule.ptr());
115c83318e3SAdam Paszke     return !mlirBlockIsNull(value);
116c83318e3SAdam Paszke   }
117c83318e3SAdam Paszke };
118c83318e3SAdam Paszke 
119f13893f6SStella Laurenzo /// Casts object -> MlirContext.
120f13893f6SStella Laurenzo template <>
121f13893f6SStella Laurenzo struct type_caster<MlirContext> {
122f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
123f13893f6SStella Laurenzo   bool load(handle src, bool) {
124f13893f6SStella Laurenzo     if (src.is_none()) {
125f13893f6SStella Laurenzo       // Gets the current thread-bound context.
126f13893f6SStella Laurenzo       // TODO: This raises an error of "No current context" currently.
127f13893f6SStella Laurenzo       // Update the implementation to pretty-print the helpful error that the
128f13893f6SStella Laurenzo       // core implementations print in this case.
129e78b745cSStella Laurenzo       src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
130f13893f6SStella Laurenzo                 .attr("Context")
131f13893f6SStella Laurenzo                 .attr("current");
132f13893f6SStella Laurenzo     }
133f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
134f13893f6SStella Laurenzo     value = mlirPythonCapsuleToContext(capsule.ptr());
13558ec17cbSMehdi Amini     return !mlirContextIsNull(value);
136f13893f6SStella Laurenzo   }
137f13893f6SStella Laurenzo };
138f13893f6SStella Laurenzo 
1395e83a5b4SStella Laurenzo /// Casts object <-> MlirDialectRegistry.
1405e83a5b4SStella Laurenzo template <>
1415e83a5b4SStella Laurenzo struct type_caster<MlirDialectRegistry> {
1425e83a5b4SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"));
1435e83a5b4SStella Laurenzo   bool load(handle src, bool) {
1445e83a5b4SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
1455e83a5b4SStella Laurenzo     value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1465e83a5b4SStella Laurenzo     return !mlirDialectRegistryIsNull(value);
1475e83a5b4SStella Laurenzo   }
1485e83a5b4SStella Laurenzo   static handle cast(MlirDialectRegistry v, return_value_policy, handle) {
1495e83a5b4SStella Laurenzo     py::object capsule = py::reinterpret_steal<py::object>(
1505e83a5b4SStella Laurenzo         mlirPythonDialectRegistryToCapsule(v));
1515e83a5b4SStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
1525e83a5b4SStella Laurenzo         .attr("DialectRegistry")
1535e83a5b4SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
1545e83a5b4SStella Laurenzo         .release();
1555e83a5b4SStella Laurenzo   }
1565e83a5b4SStella Laurenzo };
1575e83a5b4SStella Laurenzo 
158f13893f6SStella Laurenzo /// Casts object <-> MlirLocation.
159f13893f6SStella Laurenzo template <>
160f13893f6SStella Laurenzo struct type_caster<MlirLocation> {
161f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
162f13893f6SStella Laurenzo   bool load(handle src, bool) {
16394d96c2aSJohn Demme     if (src.is_none()) {
16494d96c2aSJohn Demme       // Gets the current thread-bound context.
16594d96c2aSJohn Demme       src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
16694d96c2aSJohn Demme                 .attr("Location")
16794d96c2aSJohn Demme                 .attr("current");
16894d96c2aSJohn Demme     }
169f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
170f13893f6SStella Laurenzo     value = mlirPythonCapsuleToLocation(capsule.ptr());
17158ec17cbSMehdi Amini     return !mlirLocationIsNull(value);
172f13893f6SStella Laurenzo   }
173f13893f6SStella Laurenzo   static handle cast(MlirLocation v, return_value_policy, handle) {
174f13893f6SStella Laurenzo     py::object capsule =
175f13893f6SStella Laurenzo         py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
176e78b745cSStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
177f13893f6SStella Laurenzo         .attr("Location")
178f13893f6SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
179f13893f6SStella Laurenzo         .release();
180f13893f6SStella Laurenzo   }
181f13893f6SStella Laurenzo };
182f13893f6SStella Laurenzo 
183f13893f6SStella Laurenzo /// Casts object <-> MlirModule.
184f13893f6SStella Laurenzo template <>
185f13893f6SStella Laurenzo struct type_caster<MlirModule> {
186f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
187f13893f6SStella Laurenzo   bool load(handle src, bool) {
188f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
189f13893f6SStella Laurenzo     value = mlirPythonCapsuleToModule(capsule.ptr());
19058ec17cbSMehdi Amini     return !mlirModuleIsNull(value);
191f13893f6SStella Laurenzo   }
192f13893f6SStella Laurenzo   static handle cast(MlirModule v, return_value_policy, handle) {
193f13893f6SStella Laurenzo     py::object capsule =
194f13893f6SStella Laurenzo         py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v));
195e78b745cSStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
196f13893f6SStella Laurenzo         .attr("Module")
197f13893f6SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
198f13893f6SStella Laurenzo         .release();
199f13893f6SStella Laurenzo   };
200f13893f6SStella Laurenzo };
201f13893f6SStella Laurenzo 
20218cf1cd9SJacques Pienaar /// Casts object <-> MlirFrozenRewritePatternSet.
20318cf1cd9SJacques Pienaar template <>
20418cf1cd9SJacques Pienaar struct type_caster<MlirFrozenRewritePatternSet> {
20518cf1cd9SJacques Pienaar   PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
20618cf1cd9SJacques Pienaar                        _("MlirFrozenRewritePatternSet"));
20718cf1cd9SJacques Pienaar   bool load(handle src, bool) {
20818cf1cd9SJacques Pienaar     py::object capsule = mlirApiObjectToCapsule(src);
20918cf1cd9SJacques Pienaar     value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
21018cf1cd9SJacques Pienaar     return value.ptr != nullptr;
21118cf1cd9SJacques Pienaar   }
21218cf1cd9SJacques Pienaar   static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
21318cf1cd9SJacques Pienaar                      handle) {
21418cf1cd9SJacques Pienaar     py::object capsule = py::reinterpret_steal<py::object>(
21518cf1cd9SJacques Pienaar         mlirPythonFrozenRewritePatternSetToCapsule(v));
21618cf1cd9SJacques Pienaar     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
21718cf1cd9SJacques Pienaar         .attr("FrozenRewritePatternSet")
21818cf1cd9SJacques Pienaar         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
21918cf1cd9SJacques Pienaar         .release();
22018cf1cd9SJacques Pienaar   };
22118cf1cd9SJacques Pienaar };
22218cf1cd9SJacques Pienaar 
223f13893f6SStella Laurenzo /// Casts object <-> MlirOperation.
224f13893f6SStella Laurenzo template <>
225f13893f6SStella Laurenzo struct type_caster<MlirOperation> {
226f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
227f13893f6SStella Laurenzo   bool load(handle src, bool) {
228f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
229f13893f6SStella Laurenzo     value = mlirPythonCapsuleToOperation(capsule.ptr());
23058ec17cbSMehdi Amini     return !mlirOperationIsNull(value);
231f13893f6SStella Laurenzo   }
232f13893f6SStella Laurenzo   static handle cast(MlirOperation v, return_value_policy, handle) {
233f13893f6SStella Laurenzo     if (v.ptr == nullptr)
234f13893f6SStella Laurenzo       return py::none();
235f13893f6SStella Laurenzo     py::object capsule =
236f13893f6SStella Laurenzo         py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v));
237e78b745cSStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
238f13893f6SStella Laurenzo         .attr("Operation")
239f13893f6SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
240f13893f6SStella Laurenzo         .release();
241f13893f6SStella Laurenzo   };
242f13893f6SStella Laurenzo };
243f13893f6SStella Laurenzo 
2445e9b6a22SJohn Demme /// Casts object <-> MlirValue.
2455e9b6a22SJohn Demme template <>
2465e9b6a22SJohn Demme struct type_caster<MlirValue> {
2475e9b6a22SJohn Demme   PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"));
2485e9b6a22SJohn Demme   bool load(handle src, bool) {
2495e9b6a22SJohn Demme     py::object capsule = mlirApiObjectToCapsule(src);
2505e9b6a22SJohn Demme     value = mlirPythonCapsuleToValue(capsule.ptr());
2515e9b6a22SJohn Demme     return !mlirValueIsNull(value);
2525e9b6a22SJohn Demme   }
2535e9b6a22SJohn Demme   static handle cast(MlirValue v, return_value_policy, handle) {
2545e9b6a22SJohn Demme     if (v.ptr == nullptr)
2555e9b6a22SJohn Demme       return py::none();
2565e9b6a22SJohn Demme     py::object capsule =
2575e9b6a22SJohn Demme         py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(v));
2585e9b6a22SJohn Demme     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
2595e9b6a22SJohn Demme         .attr("Value")
2605e9b6a22SJohn Demme         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
2617c850867SMaksim Levental         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
2625e9b6a22SJohn Demme         .release();
2635e9b6a22SJohn Demme   };
2645e9b6a22SJohn Demme };
2655e9b6a22SJohn Demme 
266f13893f6SStella Laurenzo /// Casts object -> MlirPassManager.
267f13893f6SStella Laurenzo template <>
268f13893f6SStella Laurenzo struct type_caster<MlirPassManager> {
269f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
270f13893f6SStella Laurenzo   bool load(handle src, bool) {
271f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
272f13893f6SStella Laurenzo     value = mlirPythonCapsuleToPassManager(capsule.ptr());
27358ec17cbSMehdi Amini     return !mlirPassManagerIsNull(value);
274f13893f6SStella Laurenzo   }
275f13893f6SStella Laurenzo };
276f13893f6SStella Laurenzo 
277d39a7844Smax /// Casts object <-> MlirTypeID.
278d39a7844Smax template <>
279d39a7844Smax struct type_caster<MlirTypeID> {
280d39a7844Smax   PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
281d39a7844Smax   bool load(handle src, bool) {
282d39a7844Smax     py::object capsule = mlirApiObjectToCapsule(src);
283d39a7844Smax     value = mlirPythonCapsuleToTypeID(capsule.ptr());
284d39a7844Smax     return !mlirTypeIDIsNull(value);
285d39a7844Smax   }
286d39a7844Smax   static handle cast(MlirTypeID v, return_value_policy, handle) {
287d39a7844Smax     if (v.ptr == nullptr)
288d39a7844Smax       return py::none();
289d39a7844Smax     py::object capsule =
290d39a7844Smax         py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
291d39a7844Smax     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
292d39a7844Smax         .attr("TypeID")
293d39a7844Smax         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
294d39a7844Smax         .release();
295d39a7844Smax   };
296d39a7844Smax };
297d39a7844Smax 
298f13893f6SStella Laurenzo /// Casts object <-> MlirType.
299f13893f6SStella Laurenzo template <>
300f13893f6SStella Laurenzo struct type_caster<MlirType> {
301f13893f6SStella Laurenzo   PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
302f13893f6SStella Laurenzo   bool load(handle src, bool) {
303f13893f6SStella Laurenzo     py::object capsule = mlirApiObjectToCapsule(src);
304f13893f6SStella Laurenzo     value = mlirPythonCapsuleToType(capsule.ptr());
30558ec17cbSMehdi Amini     return !mlirTypeIsNull(value);
306f13893f6SStella Laurenzo   }
307f13893f6SStella Laurenzo   static handle cast(MlirType t, return_value_policy, handle) {
308f13893f6SStella Laurenzo     py::object capsule =
309f13893f6SStella Laurenzo         py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
310e78b745cSStella Laurenzo     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
311f13893f6SStella Laurenzo         .attr("Type")
312f13893f6SStella Laurenzo         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
313bfb1ba75Smax         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
314f13893f6SStella Laurenzo         .release();
315f13893f6SStella Laurenzo   }
316f13893f6SStella Laurenzo };
317f13893f6SStella Laurenzo 
318f13893f6SStella Laurenzo } // namespace detail
319f13893f6SStella Laurenzo } // namespace pybind11
320f13893f6SStella Laurenzo 
321f13893f6SStella Laurenzo namespace mlir {
322f13893f6SStella Laurenzo namespace python {
323f13893f6SStella Laurenzo namespace adaptors {
324f13893f6SStella Laurenzo 
325f13893f6SStella Laurenzo /// Provides a facility like py::class_ for defining a new class in a scope,
326f13893f6SStella Laurenzo /// but this allows extension of an arbitrary Python class, defining methods
327f13893f6SStella Laurenzo /// on it is a similar way. Classes defined in this way are very similar to
328f13893f6SStella Laurenzo /// if defined in Python in the usual way but use Pybind11 machinery to do
329f13893f6SStella Laurenzo /// it. These are not "real" Pybind11 classes but pure Python classes with no
330f13893f6SStella Laurenzo /// relation to a concrete C++ class.
331f13893f6SStella Laurenzo ///
332f13893f6SStella Laurenzo /// Derived from a discussion upstream:
333f13893f6SStella Laurenzo ///   https://github.com/pybind/pybind11/issues/1193
334f13893f6SStella Laurenzo ///   (plus a fair amount of extra curricular poking)
335f13893f6SStella Laurenzo ///   TODO: If this proves useful, see about including it in pybind11.
336f13893f6SStella Laurenzo class pure_subclass {
337f13893f6SStella Laurenzo public:
338f13893f6SStella Laurenzo   pure_subclass(py::handle scope, const char *derivedClassName,
339e8d07395SMehdi Amini                 const py::object &superClass) {
340f13893f6SStella Laurenzo     py::object pyType =
341f13893f6SStella Laurenzo         py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
342f13893f6SStella Laurenzo     py::object metaclass = pyType(superClass);
343f13893f6SStella Laurenzo     py::dict attributes;
344f13893f6SStella Laurenzo 
345f13893f6SStella Laurenzo     thisClass =
346f13893f6SStella Laurenzo         metaclass(derivedClassName, py::make_tuple(superClass), attributes);
347f13893f6SStella Laurenzo     scope.attr(derivedClassName) = thisClass;
348f13893f6SStella Laurenzo   }
349f13893f6SStella Laurenzo 
350f13893f6SStella Laurenzo   template <typename Func, typename... Extra>
351f13893f6SStella Laurenzo   pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
352f13893f6SStella Laurenzo     py::cpp_function cf(
3538b734798SAlex Zinenko         std::forward<Func>(f), py::name(name), py::is_method(thisClass),
354f13893f6SStella Laurenzo         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
355f13893f6SStella Laurenzo     thisClass.attr(cf.name()) = cf;
356f13893f6SStella Laurenzo     return *this;
357f13893f6SStella Laurenzo   }
358f13893f6SStella Laurenzo 
359f13893f6SStella Laurenzo   template <typename Func, typename... Extra>
360f13893f6SStella Laurenzo   pure_subclass &def_property_readonly(const char *name, Func &&f,
361f13893f6SStella Laurenzo                                        const Extra &...extra) {
362f13893f6SStella Laurenzo     py::cpp_function cf(
3638b734798SAlex Zinenko         std::forward<Func>(f), py::name(name), py::is_method(thisClass),
364f13893f6SStella Laurenzo         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
365f13893f6SStella Laurenzo     auto builtinProperty =
366f13893f6SStella Laurenzo         py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
367f13893f6SStella Laurenzo     thisClass.attr(name) = builtinProperty(cf);
368f13893f6SStella Laurenzo     return *this;
369f13893f6SStella Laurenzo   }
370f13893f6SStella Laurenzo 
371f13893f6SStella Laurenzo   template <typename Func, typename... Extra>
372f13893f6SStella Laurenzo   pure_subclass &def_staticmethod(const char *name, Func &&f,
373f13893f6SStella Laurenzo                                   const Extra &...extra) {
374f13893f6SStella Laurenzo     static_assert(!std::is_member_function_pointer<Func>::value,
375f13893f6SStella Laurenzo                   "def_staticmethod(...) called with a non-static member "
376f13893f6SStella Laurenzo                   "function pointer");
377*b56d1ec6SPeter Hawkins     py::cpp_function cf(std::forward<Func>(f), py::name(name),
378*b56d1ec6SPeter Hawkins                         py::scope(thisClass), extra...);
379f13893f6SStella Laurenzo     thisClass.attr(cf.name()) = py::staticmethod(cf);
380f13893f6SStella Laurenzo     return *this;
381f13893f6SStella Laurenzo   }
382f13893f6SStella Laurenzo 
383f13893f6SStella Laurenzo   template <typename Func, typename... Extra>
384f13893f6SStella Laurenzo   pure_subclass &def_classmethod(const char *name, Func &&f,
385f13893f6SStella Laurenzo                                  const Extra &...extra) {
386f13893f6SStella Laurenzo     static_assert(!std::is_member_function_pointer<Func>::value,
387f13893f6SStella Laurenzo                   "def_classmethod(...) called with a non-static member "
388f13893f6SStella Laurenzo                   "function pointer");
389*b56d1ec6SPeter Hawkins     py::cpp_function cf(std::forward<Func>(f), py::name(name),
390*b56d1ec6SPeter Hawkins                         py::scope(thisClass), extra...);
391f13893f6SStella Laurenzo     thisClass.attr(cf.name()) =
392f13893f6SStella Laurenzo         py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
393f13893f6SStella Laurenzo     return *this;
394f13893f6SStella Laurenzo   }
395f13893f6SStella Laurenzo 
39666d4090dSAlex Zinenko   py::object get_class() const { return thisClass; }
39766d4090dSAlex Zinenko 
398f13893f6SStella Laurenzo protected:
399f13893f6SStella Laurenzo   py::object superClass;
400f13893f6SStella Laurenzo   py::object thisClass;
401f13893f6SStella Laurenzo };
402f13893f6SStella Laurenzo 
403f13893f6SStella Laurenzo /// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
404f13893f6SStella Laurenzo /// constructor and type checking methods.
405f13893f6SStella Laurenzo class mlir_attribute_subclass : public pure_subclass {
406f13893f6SStella Laurenzo public:
407f13893f6SStella Laurenzo   using IsAFunctionTy = bool (*)(MlirAttribute);
40893156458SMaksim Levental   using GetTypeIDFunctionTy = MlirTypeID (*)();
409f13893f6SStella Laurenzo 
410f13893f6SStella Laurenzo   /// Subclasses by looking up the super-class dynamically.
411f13893f6SStella Laurenzo   mlir_attribute_subclass(py::handle scope, const char *attrClassName,
41293156458SMaksim Levental                           IsAFunctionTy isaFunction,
41393156458SMaksim Levental                           GetTypeIDFunctionTy getTypeIDFunction = nullptr)
414f13893f6SStella Laurenzo       : mlir_attribute_subclass(
415f13893f6SStella Laurenzo             scope, attrClassName, isaFunction,
416e78b745cSStella Laurenzo             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
41793156458SMaksim Levental                 .attr("Attribute"),
41893156458SMaksim Levental             getTypeIDFunction) {}
419f13893f6SStella Laurenzo 
420f13893f6SStella Laurenzo   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
421f13893f6SStella Laurenzo   /// be used if the subclass is being defined in the same extension module
422f13893f6SStella Laurenzo   /// as the mlir.ir class (otherwise, it will trigger a recursive
423f13893f6SStella Laurenzo   /// initialization).
424f13893f6SStella Laurenzo   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
42593156458SMaksim Levental                           IsAFunctionTy isaFunction, const py::object &superCls,
42693156458SMaksim Levental                           GetTypeIDFunctionTy getTypeIDFunction = nullptr)
42789a92fb3SAlex Zinenko       : pure_subclass(scope, typeClassName, superCls) {
42889a92fb3SAlex Zinenko     // Casting constructor. Note that it hard, if not impossible, to properly
42989a92fb3SAlex Zinenko     // call chain to parent `__init__` in pybind11 due to its special handling
43089a92fb3SAlex Zinenko     // for init functions that don't have a fully constructed self-reference,
43189a92fb3SAlex Zinenko     // which makes it impossible to forward it to `__init__` of a superclass.
43289a92fb3SAlex Zinenko     // Instead, provide a custom `__new__` and call that of a superclass, which
43389a92fb3SAlex Zinenko     // eventually calls `__init__` of the superclass. Since attribute subclasses
43489a92fb3SAlex Zinenko     // have no additional members, we can just return the instance thus created
43589a92fb3SAlex Zinenko     // without amending it.
436f13893f6SStella Laurenzo     std::string captureTypeName(
437f13893f6SStella Laurenzo         typeClassName); // As string in case if typeClassName is not static.
43889a92fb3SAlex Zinenko     py::cpp_function newCf(
43989a92fb3SAlex Zinenko         [superCls, isaFunction, captureTypeName](py::object cls,
44089a92fb3SAlex Zinenko                                                  py::object otherAttribute) {
44189a92fb3SAlex Zinenko           MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
442f13893f6SStella Laurenzo           if (!isaFunction(rawAttribute)) {
44389a92fb3SAlex Zinenko             auto origRepr = py::repr(otherAttribute).cast<std::string>();
444f13893f6SStella Laurenzo             throw std::invalid_argument(
445f13893f6SStella Laurenzo                 (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
446f13893f6SStella Laurenzo                  " (from " + origRepr + ")")
447f13893f6SStella Laurenzo                     .str());
448f13893f6SStella Laurenzo           }
44989a92fb3SAlex Zinenko           py::object self = superCls.attr("__new__")(cls, otherAttribute);
45089a92fb3SAlex Zinenko           return self;
451f13893f6SStella Laurenzo         },
45289a92fb3SAlex Zinenko         py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
45389a92fb3SAlex Zinenko     thisClass.attr("__new__") = newCf;
454f13893f6SStella Laurenzo 
455f13893f6SStella Laurenzo     // 'isinstance' method.
456f13893f6SStella Laurenzo     def_staticmethod(
457f13893f6SStella Laurenzo         "isinstance",
458f13893f6SStella Laurenzo         [isaFunction](MlirAttribute other) { return isaFunction(other); },
459f13893f6SStella Laurenzo         py::arg("other_attribute"));
46093156458SMaksim Levental     def("__repr__", [superCls, captureTypeName](py::object self) {
46193156458SMaksim Levental       return py::repr(superCls(self))
46293156458SMaksim Levental           .attr("replace")(superCls.attr("__name__"), captureTypeName);
46393156458SMaksim Levental     });
46493156458SMaksim Levental     if (getTypeIDFunction) {
46593156458SMaksim Levental       def_staticmethod("get_static_typeid",
46693156458SMaksim Levental                        [getTypeIDFunction]() { return getTypeIDFunction(); });
46793156458SMaksim Levental       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
46893156458SMaksim Levental           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
46993156458SMaksim Levental               getTypeIDFunction())(pybind11::cpp_function(
47093156458SMaksim Levental               [thisClass = thisClass](const py::object &mlirAttribute) {
47193156458SMaksim Levental                 return thisClass(mlirAttribute);
47293156458SMaksim Levental               }));
47393156458SMaksim Levental     }
474f13893f6SStella Laurenzo   }
475f13893f6SStella Laurenzo };
476f13893f6SStella Laurenzo 
477f13893f6SStella Laurenzo /// Creates a custom subclass of mlir.ir.Type, implementing a casting
478f13893f6SStella Laurenzo /// constructor and type checking methods.
479f13893f6SStella Laurenzo class mlir_type_subclass : public pure_subclass {
480f13893f6SStella Laurenzo public:
481f13893f6SStella Laurenzo   using IsAFunctionTy = bool (*)(MlirType);
482bfb1ba75Smax   using GetTypeIDFunctionTy = MlirTypeID (*)();
483f13893f6SStella Laurenzo 
484f13893f6SStella Laurenzo   /// Subclasses by looking up the super-class dynamically.
485f13893f6SStella Laurenzo   mlir_type_subclass(py::handle scope, const char *typeClassName,
486bfb1ba75Smax                      IsAFunctionTy isaFunction,
487bfb1ba75Smax                      GetTypeIDFunctionTy getTypeIDFunction = nullptr)
488f13893f6SStella Laurenzo       : mlir_type_subclass(
489f13893f6SStella Laurenzo             scope, typeClassName, isaFunction,
490bfb1ba75Smax             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
491bfb1ba75Smax             getTypeIDFunction) {}
492f13893f6SStella Laurenzo 
493f13893f6SStella Laurenzo   /// Subclasses with a provided mlir.ir.Type super-class. This must
494f13893f6SStella Laurenzo   /// be used if the subclass is being defined in the same extension module
495f13893f6SStella Laurenzo   /// as the mlir.ir class (otherwise, it will trigger a recursive
496f13893f6SStella Laurenzo   /// initialization).
497f13893f6SStella Laurenzo   mlir_type_subclass(py::handle scope, const char *typeClassName,
498bfb1ba75Smax                      IsAFunctionTy isaFunction, const py::object &superCls,
499bfb1ba75Smax                      GetTypeIDFunctionTy getTypeIDFunction = nullptr)
50089a92fb3SAlex Zinenko       : pure_subclass(scope, typeClassName, superCls) {
50189a92fb3SAlex Zinenko     // Casting constructor. Note that it hard, if not impossible, to properly
50289a92fb3SAlex Zinenko     // call chain to parent `__init__` in pybind11 due to its special handling
50389a92fb3SAlex Zinenko     // for init functions that don't have a fully constructed self-reference,
50489a92fb3SAlex Zinenko     // which makes it impossible to forward it to `__init__` of a superclass.
50589a92fb3SAlex Zinenko     // Instead, provide a custom `__new__` and call that of a superclass, which
50689a92fb3SAlex Zinenko     // eventually calls `__init__` of the superclass. Since attribute subclasses
50789a92fb3SAlex Zinenko     // have no additional members, we can just return the instance thus created
50889a92fb3SAlex Zinenko     // without amending it.
509f13893f6SStella Laurenzo     std::string captureTypeName(
510f13893f6SStella Laurenzo         typeClassName); // As string in case if typeClassName is not static.
51189a92fb3SAlex Zinenko     py::cpp_function newCf(
51289a92fb3SAlex Zinenko         [superCls, isaFunction, captureTypeName](py::object cls,
5137b1ceee6SAlex Zinenko                                                  py::object otherType) {
514f13893f6SStella Laurenzo           MlirType rawType = py::cast<MlirType>(otherType);
515f13893f6SStella Laurenzo           if (!isaFunction(rawType)) {
516f13893f6SStella Laurenzo             auto origRepr = py::repr(otherType).cast<std::string>();
517f13893f6SStella Laurenzo             throw std::invalid_argument((llvm::Twine("Cannot cast type to ") +
518f13893f6SStella Laurenzo                                          captureTypeName + " (from " +
519f13893f6SStella Laurenzo                                          origRepr + ")")
520f13893f6SStella Laurenzo                                             .str());
521f13893f6SStella Laurenzo           }
52289a92fb3SAlex Zinenko           py::object self = superCls.attr("__new__")(cls, otherType);
52389a92fb3SAlex Zinenko           return self;
524f13893f6SStella Laurenzo         },
52589a92fb3SAlex Zinenko         py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
52689a92fb3SAlex Zinenko     thisClass.attr("__new__") = newCf;
527f13893f6SStella Laurenzo 
528f13893f6SStella Laurenzo     // 'isinstance' method.
529f13893f6SStella Laurenzo     def_staticmethod(
530f13893f6SStella Laurenzo         "isinstance",
531f13893f6SStella Laurenzo         [isaFunction](MlirType other) { return isaFunction(other); },
532f13893f6SStella Laurenzo         py::arg("other_type"));
533bfb1ba75Smax     def("__repr__", [superCls, captureTypeName](py::object self) {
534bfb1ba75Smax       return py::repr(superCls(self))
535bfb1ba75Smax           .attr("replace")(superCls.attr("__name__"), captureTypeName);
536bfb1ba75Smax     });
537bfb1ba75Smax     if (getTypeIDFunction) {
538681eacc1Smartin-luecke       // 'get_static_typeid' method.
539681eacc1Smartin-luecke       // This is modeled as a static method instead of a static property because
540681eacc1Smartin-luecke       // `def_property_readonly_static` is not available in `pure_subclass` and
541681eacc1Smartin-luecke       // we do not want to introduce the complexity that pybind uses to
542681eacc1Smartin-luecke       // implement it.
543681eacc1Smartin-luecke       def_staticmethod("get_static_typeid",
544681eacc1Smartin-luecke                        [getTypeIDFunction]() { return getTypeIDFunction(); });
545bfb1ba75Smax       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
546bfb1ba75Smax           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
5477c850867SMaksim Levental               getTypeIDFunction())(pybind11::cpp_function(
548bfb1ba75Smax               [thisClass = thisClass](const py::object &mlirType) {
549bfb1ba75Smax                 return thisClass(mlirType);
550bfb1ba75Smax               }));
551bfb1ba75Smax     }
552f13893f6SStella Laurenzo   }
553f13893f6SStella Laurenzo };
554f13893f6SStella Laurenzo 
55569cc3cfbSmax /// Creates a custom subclass of mlir.ir.Value, implementing a casting
55669cc3cfbSmax /// constructor and type checking methods.
55769cc3cfbSmax class mlir_value_subclass : public pure_subclass {
55869cc3cfbSmax public:
55969cc3cfbSmax   using IsAFunctionTy = bool (*)(MlirValue);
56069cc3cfbSmax 
56169cc3cfbSmax   /// Subclasses by looking up the super-class dynamically.
56269cc3cfbSmax   mlir_value_subclass(py::handle scope, const char *valueClassName,
56369cc3cfbSmax                       IsAFunctionTy isaFunction)
56469cc3cfbSmax       : mlir_value_subclass(
56569cc3cfbSmax             scope, valueClassName, isaFunction,
56669cc3cfbSmax             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
56769cc3cfbSmax   }
56869cc3cfbSmax 
56969cc3cfbSmax   /// Subclasses with a provided mlir.ir.Value super-class. This must
57069cc3cfbSmax   /// be used if the subclass is being defined in the same extension module
57169cc3cfbSmax   /// as the mlir.ir class (otherwise, it will trigger a recursive
57269cc3cfbSmax   /// initialization).
57369cc3cfbSmax   mlir_value_subclass(py::handle scope, const char *valueClassName,
57469cc3cfbSmax                       IsAFunctionTy isaFunction, const py::object &superCls)
57569cc3cfbSmax       : pure_subclass(scope, valueClassName, superCls) {
57669cc3cfbSmax     // Casting constructor. Note that it hard, if not impossible, to properly
57769cc3cfbSmax     // call chain to parent `__init__` in pybind11 due to its special handling
57869cc3cfbSmax     // for init functions that don't have a fully constructed self-reference,
57969cc3cfbSmax     // which makes it impossible to forward it to `__init__` of a superclass.
58069cc3cfbSmax     // Instead, provide a custom `__new__` and call that of a superclass, which
58169cc3cfbSmax     // eventually calls `__init__` of the superclass. Since attribute subclasses
58269cc3cfbSmax     // have no additional members, we can just return the instance thus created
58369cc3cfbSmax     // without amending it.
58469cc3cfbSmax     std::string captureValueName(
58569cc3cfbSmax         valueClassName); // As string in case if valueClassName is not static.
58669cc3cfbSmax     py::cpp_function newCf(
58769cc3cfbSmax         [superCls, isaFunction, captureValueName](py::object cls,
58869cc3cfbSmax                                                   py::object otherValue) {
58969cc3cfbSmax           MlirValue rawValue = py::cast<MlirValue>(otherValue);
59069cc3cfbSmax           if (!isaFunction(rawValue)) {
59169cc3cfbSmax             auto origRepr = py::repr(otherValue).cast<std::string>();
59269cc3cfbSmax             throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
59369cc3cfbSmax                                          captureValueName + " (from " +
59469cc3cfbSmax                                          origRepr + ")")
59569cc3cfbSmax                                             .str());
59669cc3cfbSmax           }
59769cc3cfbSmax           py::object self = superCls.attr("__new__")(cls, otherValue);
59869cc3cfbSmax           return self;
59969cc3cfbSmax         },
60069cc3cfbSmax         py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
60169cc3cfbSmax     thisClass.attr("__new__") = newCf;
60269cc3cfbSmax 
60369cc3cfbSmax     // 'isinstance' method.
60469cc3cfbSmax     def_staticmethod(
60569cc3cfbSmax         "isinstance",
60669cc3cfbSmax         [isaFunction](MlirValue other) { return isaFunction(other); },
60769cc3cfbSmax         py::arg("other_value"));
60869cc3cfbSmax   }
60969cc3cfbSmax };
61069cc3cfbSmax 
611f13893f6SStella Laurenzo } // namespace adaptors
61291f11611SOleksandr "Alex" Zinenko 
613f13893f6SStella Laurenzo } // namespace python
614f13893f6SStella Laurenzo } // namespace mlir
615f13893f6SStella Laurenzo 
6168f23296bSMehdi Amini #endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
617