xref: /llvm-project/mlir/include/mlir/Bindings/Python/PybindAdaptors.h (revision b56d1ec6cb8b5cb3ff46cba39a1049ecf3831afb)
1 //===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains adaptors for clients of the core MLIR Python APIs to
9 // interop via MLIR CAPI types, using pybind11. The facilities here do not
10 // depend on implementation details of the MLIR Python API and do not introduce
11 // C++-level dependencies with it (requiring only Python and CAPI-level
12 // dependencies).
13 //
14 // It is encouraged to be used both in-tree and out-of-tree. For in-tree use
15 // cases, it should be used for dialect implementations (versus relying on
16 // Pybind-based internals of the core libraries).
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
20 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
21 
22 #include <pybind11/functional.h>
23 #include <pybind11/pybind11.h>
24 #include <pybind11/pytypes.h>
25 #include <pybind11/stl.h>
26 
27 #include "mlir-c/Bindings/Python/Interop.h"
28 #include "mlir-c/Diagnostics.h"
29 #include "mlir-c/IR.h"
30 
31 #include "llvm/ADT/Twine.h"
32 
33 namespace py = pybind11;
34 using namespace py::literals;
35 
36 // Raw CAPI type casters need to be declared before use, so always include them
37 // first.
38 namespace pybind11 {
39 namespace detail {
40 
41 /// Helper to convert a presumed MLIR API object to a capsule, accepting either
42 /// an explicit Capsule (which can happen when two C APIs are communicating
43 /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
44 /// attribute (through which supported MLIR Python API objects export their
45 /// contained API pointer as a capsule). Throws a type error if the object is
46 /// neither. This is intended to be used from type casters, which are invoked
47 /// with a raw handle (unowned). The returned object's lifetime may not extend
48 /// beyond the apiObject handle without explicitly having its refcount increased
49 /// (i.e. on return).
50 static py::object mlirApiObjectToCapsule(py::handle apiObject) {
51   if (PyCapsule_CheckExact(apiObject.ptr()))
52     return py::reinterpret_borrow<py::object>(apiObject);
53   if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) {
54     auto repr = py::repr(apiObject).cast<std::string>();
55     throw py::type_error(
56         (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str());
57   }
58   return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
59 }
60 
61 // Note: Currently all of the following support cast from py::object to the
62 // Mlir* C-API type, but only a few light-weight, context-bound ones
63 // implicitly cast the other way because the use case has not yet emerged and
64 // ownership is unclear.
65 
66 /// Casts object <-> MlirAffineMap.
67 template <>
68 struct type_caster<MlirAffineMap> {
69   PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"));
70   bool load(handle src, bool) {
71     py::object capsule = mlirApiObjectToCapsule(src);
72     value = mlirPythonCapsuleToAffineMap(capsule.ptr());
73     if (mlirAffineMapIsNull(value)) {
74       return false;
75     }
76     return !mlirAffineMapIsNull(value);
77   }
78   static handle cast(MlirAffineMap v, return_value_policy, handle) {
79     py::object capsule =
80         py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v));
81     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
82         .attr("AffineMap")
83         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
84         .release();
85   }
86 };
87 
88 /// Casts object <-> MlirAttribute.
89 template <>
90 struct type_caster<MlirAttribute> {
91   PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
92   bool load(handle src, bool) {
93     py::object capsule = mlirApiObjectToCapsule(src);
94     value = mlirPythonCapsuleToAttribute(capsule.ptr());
95     return !mlirAttributeIsNull(value);
96   }
97   static handle cast(MlirAttribute v, return_value_policy, handle) {
98     py::object capsule =
99         py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
100     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
101         .attr("Attribute")
102         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
103         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
104         .release();
105   }
106 };
107 
108 /// Casts object -> MlirBlock.
109 template <>
110 struct type_caster<MlirBlock> {
111   PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"));
112   bool load(handle src, bool) {
113     py::object capsule = mlirApiObjectToCapsule(src);
114     value = mlirPythonCapsuleToBlock(capsule.ptr());
115     return !mlirBlockIsNull(value);
116   }
117 };
118 
119 /// Casts object -> MlirContext.
120 template <>
121 struct type_caster<MlirContext> {
122   PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
123   bool load(handle src, bool) {
124     if (src.is_none()) {
125       // Gets the current thread-bound context.
126       // TODO: This raises an error of "No current context" currently.
127       // Update the implementation to pretty-print the helpful error that the
128       // core implementations print in this case.
129       src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
130                 .attr("Context")
131                 .attr("current");
132     }
133     py::object capsule = mlirApiObjectToCapsule(src);
134     value = mlirPythonCapsuleToContext(capsule.ptr());
135     return !mlirContextIsNull(value);
136   }
137 };
138 
139 /// Casts object <-> MlirDialectRegistry.
140 template <>
141 struct type_caster<MlirDialectRegistry> {
142   PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"));
143   bool load(handle src, bool) {
144     py::object capsule = mlirApiObjectToCapsule(src);
145     value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
146     return !mlirDialectRegistryIsNull(value);
147   }
148   static handle cast(MlirDialectRegistry v, return_value_policy, handle) {
149     py::object capsule = py::reinterpret_steal<py::object>(
150         mlirPythonDialectRegistryToCapsule(v));
151     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
152         .attr("DialectRegistry")
153         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
154         .release();
155   }
156 };
157 
158 /// Casts object <-> MlirLocation.
159 template <>
160 struct type_caster<MlirLocation> {
161   PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
162   bool load(handle src, bool) {
163     if (src.is_none()) {
164       // Gets the current thread-bound context.
165       src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
166                 .attr("Location")
167                 .attr("current");
168     }
169     py::object capsule = mlirApiObjectToCapsule(src);
170     value = mlirPythonCapsuleToLocation(capsule.ptr());
171     return !mlirLocationIsNull(value);
172   }
173   static handle cast(MlirLocation v, return_value_policy, handle) {
174     py::object capsule =
175         py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
176     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
177         .attr("Location")
178         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
179         .release();
180   }
181 };
182 
183 /// Casts object <-> MlirModule.
184 template <>
185 struct type_caster<MlirModule> {
186   PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
187   bool load(handle src, bool) {
188     py::object capsule = mlirApiObjectToCapsule(src);
189     value = mlirPythonCapsuleToModule(capsule.ptr());
190     return !mlirModuleIsNull(value);
191   }
192   static handle cast(MlirModule v, return_value_policy, handle) {
193     py::object capsule =
194         py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v));
195     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
196         .attr("Module")
197         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
198         .release();
199   };
200 };
201 
202 /// Casts object <-> MlirFrozenRewritePatternSet.
203 template <>
204 struct type_caster<MlirFrozenRewritePatternSet> {
205   PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
206                        _("MlirFrozenRewritePatternSet"));
207   bool load(handle src, bool) {
208     py::object capsule = mlirApiObjectToCapsule(src);
209     value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
210     return value.ptr != nullptr;
211   }
212   static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
213                      handle) {
214     py::object capsule = py::reinterpret_steal<py::object>(
215         mlirPythonFrozenRewritePatternSetToCapsule(v));
216     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
217         .attr("FrozenRewritePatternSet")
218         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
219         .release();
220   };
221 };
222 
223 /// Casts object <-> MlirOperation.
224 template <>
225 struct type_caster<MlirOperation> {
226   PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
227   bool load(handle src, bool) {
228     py::object capsule = mlirApiObjectToCapsule(src);
229     value = mlirPythonCapsuleToOperation(capsule.ptr());
230     return !mlirOperationIsNull(value);
231   }
232   static handle cast(MlirOperation v, return_value_policy, handle) {
233     if (v.ptr == nullptr)
234       return py::none();
235     py::object capsule =
236         py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v));
237     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
238         .attr("Operation")
239         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
240         .release();
241   };
242 };
243 
244 /// Casts object <-> MlirValue.
245 template <>
246 struct type_caster<MlirValue> {
247   PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"));
248   bool load(handle src, bool) {
249     py::object capsule = mlirApiObjectToCapsule(src);
250     value = mlirPythonCapsuleToValue(capsule.ptr());
251     return !mlirValueIsNull(value);
252   }
253   static handle cast(MlirValue v, return_value_policy, handle) {
254     if (v.ptr == nullptr)
255       return py::none();
256     py::object capsule =
257         py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(v));
258     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
259         .attr("Value")
260         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
261         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
262         .release();
263   };
264 };
265 
266 /// Casts object -> MlirPassManager.
267 template <>
268 struct type_caster<MlirPassManager> {
269   PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
270   bool load(handle src, bool) {
271     py::object capsule = mlirApiObjectToCapsule(src);
272     value = mlirPythonCapsuleToPassManager(capsule.ptr());
273     return !mlirPassManagerIsNull(value);
274   }
275 };
276 
277 /// Casts object <-> MlirTypeID.
278 template <>
279 struct type_caster<MlirTypeID> {
280   PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
281   bool load(handle src, bool) {
282     py::object capsule = mlirApiObjectToCapsule(src);
283     value = mlirPythonCapsuleToTypeID(capsule.ptr());
284     return !mlirTypeIDIsNull(value);
285   }
286   static handle cast(MlirTypeID v, return_value_policy, handle) {
287     if (v.ptr == nullptr)
288       return py::none();
289     py::object capsule =
290         py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
291     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
292         .attr("TypeID")
293         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
294         .release();
295   };
296 };
297 
298 /// Casts object <-> MlirType.
299 template <>
300 struct type_caster<MlirType> {
301   PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
302   bool load(handle src, bool) {
303     py::object capsule = mlirApiObjectToCapsule(src);
304     value = mlirPythonCapsuleToType(capsule.ptr());
305     return !mlirTypeIsNull(value);
306   }
307   static handle cast(MlirType t, return_value_policy, handle) {
308     py::object capsule =
309         py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
310     return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
311         .attr("Type")
312         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
313         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
314         .release();
315   }
316 };
317 
318 } // namespace detail
319 } // namespace pybind11
320 
321 namespace mlir {
322 namespace python {
323 namespace adaptors {
324 
325 /// Provides a facility like py::class_ for defining a new class in a scope,
326 /// but this allows extension of an arbitrary Python class, defining methods
327 /// on it is a similar way. Classes defined in this way are very similar to
328 /// if defined in Python in the usual way but use Pybind11 machinery to do
329 /// it. These are not "real" Pybind11 classes but pure Python classes with no
330 /// relation to a concrete C++ class.
331 ///
332 /// Derived from a discussion upstream:
333 ///   https://github.com/pybind/pybind11/issues/1193
334 ///   (plus a fair amount of extra curricular poking)
335 ///   TODO: If this proves useful, see about including it in pybind11.
336 class pure_subclass {
337 public:
338   pure_subclass(py::handle scope, const char *derivedClassName,
339                 const py::object &superClass) {
340     py::object pyType =
341         py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
342     py::object metaclass = pyType(superClass);
343     py::dict attributes;
344 
345     thisClass =
346         metaclass(derivedClassName, py::make_tuple(superClass), attributes);
347     scope.attr(derivedClassName) = thisClass;
348   }
349 
350   template <typename Func, typename... Extra>
351   pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
352     py::cpp_function cf(
353         std::forward<Func>(f), py::name(name), py::is_method(thisClass),
354         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
355     thisClass.attr(cf.name()) = cf;
356     return *this;
357   }
358 
359   template <typename Func, typename... Extra>
360   pure_subclass &def_property_readonly(const char *name, Func &&f,
361                                        const Extra &...extra) {
362     py::cpp_function cf(
363         std::forward<Func>(f), py::name(name), py::is_method(thisClass),
364         py::sibling(py::getattr(thisClass, name, py::none())), extra...);
365     auto builtinProperty =
366         py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
367     thisClass.attr(name) = builtinProperty(cf);
368     return *this;
369   }
370 
371   template <typename Func, typename... Extra>
372   pure_subclass &def_staticmethod(const char *name, Func &&f,
373                                   const Extra &...extra) {
374     static_assert(!std::is_member_function_pointer<Func>::value,
375                   "def_staticmethod(...) called with a non-static member "
376                   "function pointer");
377     py::cpp_function cf(std::forward<Func>(f), py::name(name),
378                         py::scope(thisClass), extra...);
379     thisClass.attr(cf.name()) = py::staticmethod(cf);
380     return *this;
381   }
382 
383   template <typename Func, typename... Extra>
384   pure_subclass &def_classmethod(const char *name, Func &&f,
385                                  const Extra &...extra) {
386     static_assert(!std::is_member_function_pointer<Func>::value,
387                   "def_classmethod(...) called with a non-static member "
388                   "function pointer");
389     py::cpp_function cf(std::forward<Func>(f), py::name(name),
390                         py::scope(thisClass), extra...);
391     thisClass.attr(cf.name()) =
392         py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
393     return *this;
394   }
395 
396   py::object get_class() const { return thisClass; }
397 
398 protected:
399   py::object superClass;
400   py::object thisClass;
401 };
402 
403 /// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
404 /// constructor and type checking methods.
405 class mlir_attribute_subclass : public pure_subclass {
406 public:
407   using IsAFunctionTy = bool (*)(MlirAttribute);
408   using GetTypeIDFunctionTy = MlirTypeID (*)();
409 
410   /// Subclasses by looking up the super-class dynamically.
411   mlir_attribute_subclass(py::handle scope, const char *attrClassName,
412                           IsAFunctionTy isaFunction,
413                           GetTypeIDFunctionTy getTypeIDFunction = nullptr)
414       : mlir_attribute_subclass(
415             scope, attrClassName, isaFunction,
416             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
417                 .attr("Attribute"),
418             getTypeIDFunction) {}
419 
420   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
421   /// be used if the subclass is being defined in the same extension module
422   /// as the mlir.ir class (otherwise, it will trigger a recursive
423   /// initialization).
424   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
425                           IsAFunctionTy isaFunction, const py::object &superCls,
426                           GetTypeIDFunctionTy getTypeIDFunction = nullptr)
427       : pure_subclass(scope, typeClassName, superCls) {
428     // Casting constructor. Note that it hard, if not impossible, to properly
429     // call chain to parent `__init__` in pybind11 due to its special handling
430     // for init functions that don't have a fully constructed self-reference,
431     // which makes it impossible to forward it to `__init__` of a superclass.
432     // Instead, provide a custom `__new__` and call that of a superclass, which
433     // eventually calls `__init__` of the superclass. Since attribute subclasses
434     // have no additional members, we can just return the instance thus created
435     // without amending it.
436     std::string captureTypeName(
437         typeClassName); // As string in case if typeClassName is not static.
438     py::cpp_function newCf(
439         [superCls, isaFunction, captureTypeName](py::object cls,
440                                                  py::object otherAttribute) {
441           MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
442           if (!isaFunction(rawAttribute)) {
443             auto origRepr = py::repr(otherAttribute).cast<std::string>();
444             throw std::invalid_argument(
445                 (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
446                  " (from " + origRepr + ")")
447                     .str());
448           }
449           py::object self = superCls.attr("__new__")(cls, otherAttribute);
450           return self;
451         },
452         py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
453     thisClass.attr("__new__") = newCf;
454 
455     // 'isinstance' method.
456     def_staticmethod(
457         "isinstance",
458         [isaFunction](MlirAttribute other) { return isaFunction(other); },
459         py::arg("other_attribute"));
460     def("__repr__", [superCls, captureTypeName](py::object self) {
461       return py::repr(superCls(self))
462           .attr("replace")(superCls.attr("__name__"), captureTypeName);
463     });
464     if (getTypeIDFunction) {
465       def_staticmethod("get_static_typeid",
466                        [getTypeIDFunction]() { return getTypeIDFunction(); });
467       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
468           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
469               getTypeIDFunction())(pybind11::cpp_function(
470               [thisClass = thisClass](const py::object &mlirAttribute) {
471                 return thisClass(mlirAttribute);
472               }));
473     }
474   }
475 };
476 
477 /// Creates a custom subclass of mlir.ir.Type, implementing a casting
478 /// constructor and type checking methods.
479 class mlir_type_subclass : public pure_subclass {
480 public:
481   using IsAFunctionTy = bool (*)(MlirType);
482   using GetTypeIDFunctionTy = MlirTypeID (*)();
483 
484   /// Subclasses by looking up the super-class dynamically.
485   mlir_type_subclass(py::handle scope, const char *typeClassName,
486                      IsAFunctionTy isaFunction,
487                      GetTypeIDFunctionTy getTypeIDFunction = nullptr)
488       : mlir_type_subclass(
489             scope, typeClassName, isaFunction,
490             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
491             getTypeIDFunction) {}
492 
493   /// Subclasses with a provided mlir.ir.Type super-class. This must
494   /// be used if the subclass is being defined in the same extension module
495   /// as the mlir.ir class (otherwise, it will trigger a recursive
496   /// initialization).
497   mlir_type_subclass(py::handle scope, const char *typeClassName,
498                      IsAFunctionTy isaFunction, const py::object &superCls,
499                      GetTypeIDFunctionTy getTypeIDFunction = nullptr)
500       : pure_subclass(scope, typeClassName, superCls) {
501     // Casting constructor. Note that it hard, if not impossible, to properly
502     // call chain to parent `__init__` in pybind11 due to its special handling
503     // for init functions that don't have a fully constructed self-reference,
504     // which makes it impossible to forward it to `__init__` of a superclass.
505     // Instead, provide a custom `__new__` and call that of a superclass, which
506     // eventually calls `__init__` of the superclass. Since attribute subclasses
507     // have no additional members, we can just return the instance thus created
508     // without amending it.
509     std::string captureTypeName(
510         typeClassName); // As string in case if typeClassName is not static.
511     py::cpp_function newCf(
512         [superCls, isaFunction, captureTypeName](py::object cls,
513                                                  py::object otherType) {
514           MlirType rawType = py::cast<MlirType>(otherType);
515           if (!isaFunction(rawType)) {
516             auto origRepr = py::repr(otherType).cast<std::string>();
517             throw std::invalid_argument((llvm::Twine("Cannot cast type to ") +
518                                          captureTypeName + " (from " +
519                                          origRepr + ")")
520                                             .str());
521           }
522           py::object self = superCls.attr("__new__")(cls, otherType);
523           return self;
524         },
525         py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
526     thisClass.attr("__new__") = newCf;
527 
528     // 'isinstance' method.
529     def_staticmethod(
530         "isinstance",
531         [isaFunction](MlirType other) { return isaFunction(other); },
532         py::arg("other_type"));
533     def("__repr__", [superCls, captureTypeName](py::object self) {
534       return py::repr(superCls(self))
535           .attr("replace")(superCls.attr("__name__"), captureTypeName);
536     });
537     if (getTypeIDFunction) {
538       // 'get_static_typeid' method.
539       // This is modeled as a static method instead of a static property because
540       // `def_property_readonly_static` is not available in `pure_subclass` and
541       // we do not want to introduce the complexity that pybind uses to
542       // implement it.
543       def_staticmethod("get_static_typeid",
544                        [getTypeIDFunction]() { return getTypeIDFunction(); });
545       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
546           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
547               getTypeIDFunction())(pybind11::cpp_function(
548               [thisClass = thisClass](const py::object &mlirType) {
549                 return thisClass(mlirType);
550               }));
551     }
552   }
553 };
554 
555 /// Creates a custom subclass of mlir.ir.Value, implementing a casting
556 /// constructor and type checking methods.
557 class mlir_value_subclass : public pure_subclass {
558 public:
559   using IsAFunctionTy = bool (*)(MlirValue);
560 
561   /// Subclasses by looking up the super-class dynamically.
562   mlir_value_subclass(py::handle scope, const char *valueClassName,
563                       IsAFunctionTy isaFunction)
564       : mlir_value_subclass(
565             scope, valueClassName, isaFunction,
566             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
567   }
568 
569   /// Subclasses with a provided mlir.ir.Value super-class. This must
570   /// be used if the subclass is being defined in the same extension module
571   /// as the mlir.ir class (otherwise, it will trigger a recursive
572   /// initialization).
573   mlir_value_subclass(py::handle scope, const char *valueClassName,
574                       IsAFunctionTy isaFunction, const py::object &superCls)
575       : pure_subclass(scope, valueClassName, superCls) {
576     // Casting constructor. Note that it hard, if not impossible, to properly
577     // call chain to parent `__init__` in pybind11 due to its special handling
578     // for init functions that don't have a fully constructed self-reference,
579     // which makes it impossible to forward it to `__init__` of a superclass.
580     // Instead, provide a custom `__new__` and call that of a superclass, which
581     // eventually calls `__init__` of the superclass. Since attribute subclasses
582     // have no additional members, we can just return the instance thus created
583     // without amending it.
584     std::string captureValueName(
585         valueClassName); // As string in case if valueClassName is not static.
586     py::cpp_function newCf(
587         [superCls, isaFunction, captureValueName](py::object cls,
588                                                   py::object otherValue) {
589           MlirValue rawValue = py::cast<MlirValue>(otherValue);
590           if (!isaFunction(rawValue)) {
591             auto origRepr = py::repr(otherValue).cast<std::string>();
592             throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
593                                          captureValueName + " (from " +
594                                          origRepr + ")")
595                                             .str());
596           }
597           py::object self = superCls.attr("__new__")(cls, otherValue);
598           return self;
599         },
600         py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
601     thisClass.attr("__new__") = newCf;
602 
603     // 'isinstance' method.
604     def_staticmethod(
605         "isinstance",
606         [isaFunction](MlirValue other) { return isaFunction(other); },
607         py::arg("other_value"));
608   }
609 };
610 
611 } // namespace adaptors
612 
613 } // namespace python
614 } // namespace mlir
615 
616 #endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
617