1 //===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file contains adaptors for clients of the core MLIR Python APIs to 9 // interop via MLIR CAPI types, using pybind11. The facilities here do not 10 // depend on implementation details of the MLIR Python API and do not introduce 11 // C++-level dependencies with it (requiring only Python and CAPI-level 12 // dependencies). 13 // 14 // It is encouraged to be used both in-tree and out-of-tree. For in-tree use 15 // cases, it should be used for dialect implementations (versus relying on 16 // Pybind-based internals of the core libraries). 17 //===----------------------------------------------------------------------===// 18 19 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 20 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 21 22 #include <pybind11/functional.h> 23 #include <pybind11/pybind11.h> 24 #include <pybind11/pytypes.h> 25 #include <pybind11/stl.h> 26 27 #include "mlir-c/Bindings/Python/Interop.h" 28 #include "mlir-c/Diagnostics.h" 29 #include "mlir-c/IR.h" 30 31 #include "llvm/ADT/Twine.h" 32 33 namespace py = pybind11; 34 using namespace py::literals; 35 36 // Raw CAPI type casters need to be declared before use, so always include them 37 // first. 38 namespace pybind11 { 39 namespace detail { 40 41 /// Helper to convert a presumed MLIR API object to a capsule, accepting either 42 /// an explicit Capsule (which can happen when two C APIs are communicating 43 /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR 44 /// attribute (through which supported MLIR Python API objects export their 45 /// contained API pointer as a capsule). Throws a type error if the object is 46 /// neither. This is intended to be used from type casters, which are invoked 47 /// with a raw handle (unowned). The returned object's lifetime may not extend 48 /// beyond the apiObject handle without explicitly having its refcount increased 49 /// (i.e. on return). 50 static py::object mlirApiObjectToCapsule(py::handle apiObject) { 51 if (PyCapsule_CheckExact(apiObject.ptr())) 52 return py::reinterpret_borrow<py::object>(apiObject); 53 if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { 54 auto repr = py::repr(apiObject).cast<std::string>(); 55 throw py::type_error( 56 (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str()); 57 } 58 return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); 59 } 60 61 // Note: Currently all of the following support cast from py::object to the 62 // Mlir* C-API type, but only a few light-weight, context-bound ones 63 // implicitly cast the other way because the use case has not yet emerged and 64 // ownership is unclear. 65 66 /// Casts object <-> MlirAffineMap. 67 template <> 68 struct type_caster<MlirAffineMap> { 69 PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap")); 70 bool load(handle src, bool) { 71 py::object capsule = mlirApiObjectToCapsule(src); 72 value = mlirPythonCapsuleToAffineMap(capsule.ptr()); 73 if (mlirAffineMapIsNull(value)) { 74 return false; 75 } 76 return !mlirAffineMapIsNull(value); 77 } 78 static handle cast(MlirAffineMap v, return_value_policy, handle) { 79 py::object capsule = 80 py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v)); 81 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 82 .attr("AffineMap") 83 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 84 .release(); 85 } 86 }; 87 88 /// Casts object <-> MlirAttribute. 89 template <> 90 struct type_caster<MlirAttribute> { 91 PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute")); 92 bool load(handle src, bool) { 93 py::object capsule = mlirApiObjectToCapsule(src); 94 value = mlirPythonCapsuleToAttribute(capsule.ptr()); 95 return !mlirAttributeIsNull(value); 96 } 97 static handle cast(MlirAttribute v, return_value_policy, handle) { 98 py::object capsule = 99 py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v)); 100 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 101 .attr("Attribute") 102 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 103 .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() 104 .release(); 105 } 106 }; 107 108 /// Casts object -> MlirBlock. 109 template <> 110 struct type_caster<MlirBlock> { 111 PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock")); 112 bool load(handle src, bool) { 113 py::object capsule = mlirApiObjectToCapsule(src); 114 value = mlirPythonCapsuleToBlock(capsule.ptr()); 115 return !mlirBlockIsNull(value); 116 } 117 }; 118 119 /// Casts object -> MlirContext. 120 template <> 121 struct type_caster<MlirContext> { 122 PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); 123 bool load(handle src, bool) { 124 if (src.is_none()) { 125 // Gets the current thread-bound context. 126 // TODO: This raises an error of "No current context" currently. 127 // Update the implementation to pretty-print the helpful error that the 128 // core implementations print in this case. 129 src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 130 .attr("Context") 131 .attr("current"); 132 } 133 py::object capsule = mlirApiObjectToCapsule(src); 134 value = mlirPythonCapsuleToContext(capsule.ptr()); 135 return !mlirContextIsNull(value); 136 } 137 }; 138 139 /// Casts object <-> MlirDialectRegistry. 140 template <> 141 struct type_caster<MlirDialectRegistry> { 142 PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); 143 bool load(handle src, bool) { 144 py::object capsule = mlirApiObjectToCapsule(src); 145 value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 146 return !mlirDialectRegistryIsNull(value); 147 } 148 static handle cast(MlirDialectRegistry v, return_value_policy, handle) { 149 py::object capsule = py::reinterpret_steal<py::object>( 150 mlirPythonDialectRegistryToCapsule(v)); 151 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 152 .attr("DialectRegistry") 153 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 154 .release(); 155 } 156 }; 157 158 /// Casts object <-> MlirLocation. 159 template <> 160 struct type_caster<MlirLocation> { 161 PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); 162 bool load(handle src, bool) { 163 if (src.is_none()) { 164 // Gets the current thread-bound context. 165 src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 166 .attr("Location") 167 .attr("current"); 168 } 169 py::object capsule = mlirApiObjectToCapsule(src); 170 value = mlirPythonCapsuleToLocation(capsule.ptr()); 171 return !mlirLocationIsNull(value); 172 } 173 static handle cast(MlirLocation v, return_value_policy, handle) { 174 py::object capsule = 175 py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v)); 176 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 177 .attr("Location") 178 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 179 .release(); 180 } 181 }; 182 183 /// Casts object <-> MlirModule. 184 template <> 185 struct type_caster<MlirModule> { 186 PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); 187 bool load(handle src, bool) { 188 py::object capsule = mlirApiObjectToCapsule(src); 189 value = mlirPythonCapsuleToModule(capsule.ptr()); 190 return !mlirModuleIsNull(value); 191 } 192 static handle cast(MlirModule v, return_value_policy, handle) { 193 py::object capsule = 194 py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v)); 195 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 196 .attr("Module") 197 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 198 .release(); 199 }; 200 }; 201 202 /// Casts object <-> MlirFrozenRewritePatternSet. 203 template <> 204 struct type_caster<MlirFrozenRewritePatternSet> { 205 PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, 206 _("MlirFrozenRewritePatternSet")); 207 bool load(handle src, bool) { 208 py::object capsule = mlirApiObjectToCapsule(src); 209 value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); 210 return value.ptr != nullptr; 211 } 212 static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, 213 handle) { 214 py::object capsule = py::reinterpret_steal<py::object>( 215 mlirPythonFrozenRewritePatternSetToCapsule(v)); 216 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) 217 .attr("FrozenRewritePatternSet") 218 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 219 .release(); 220 }; 221 }; 222 223 /// Casts object <-> MlirOperation. 224 template <> 225 struct type_caster<MlirOperation> { 226 PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation")); 227 bool load(handle src, bool) { 228 py::object capsule = mlirApiObjectToCapsule(src); 229 value = mlirPythonCapsuleToOperation(capsule.ptr()); 230 return !mlirOperationIsNull(value); 231 } 232 static handle cast(MlirOperation v, return_value_policy, handle) { 233 if (v.ptr == nullptr) 234 return py::none(); 235 py::object capsule = 236 py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v)); 237 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 238 .attr("Operation") 239 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 240 .release(); 241 }; 242 }; 243 244 /// Casts object <-> MlirValue. 245 template <> 246 struct type_caster<MlirValue> { 247 PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue")); 248 bool load(handle src, bool) { 249 py::object capsule = mlirApiObjectToCapsule(src); 250 value = mlirPythonCapsuleToValue(capsule.ptr()); 251 return !mlirValueIsNull(value); 252 } 253 static handle cast(MlirValue v, return_value_policy, handle) { 254 if (v.ptr == nullptr) 255 return py::none(); 256 py::object capsule = 257 py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(v)); 258 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 259 .attr("Value") 260 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 261 .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() 262 .release(); 263 }; 264 }; 265 266 /// Casts object -> MlirPassManager. 267 template <> 268 struct type_caster<MlirPassManager> { 269 PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); 270 bool load(handle src, bool) { 271 py::object capsule = mlirApiObjectToCapsule(src); 272 value = mlirPythonCapsuleToPassManager(capsule.ptr()); 273 return !mlirPassManagerIsNull(value); 274 } 275 }; 276 277 /// Casts object <-> MlirTypeID. 278 template <> 279 struct type_caster<MlirTypeID> { 280 PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID")); 281 bool load(handle src, bool) { 282 py::object capsule = mlirApiObjectToCapsule(src); 283 value = mlirPythonCapsuleToTypeID(capsule.ptr()); 284 return !mlirTypeIDIsNull(value); 285 } 286 static handle cast(MlirTypeID v, return_value_policy, handle) { 287 if (v.ptr == nullptr) 288 return py::none(); 289 py::object capsule = 290 py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v)); 291 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 292 .attr("TypeID") 293 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 294 .release(); 295 }; 296 }; 297 298 /// Casts object <-> MlirType. 299 template <> 300 struct type_caster<MlirType> { 301 PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); 302 bool load(handle src, bool) { 303 py::object capsule = mlirApiObjectToCapsule(src); 304 value = mlirPythonCapsuleToType(capsule.ptr()); 305 return !mlirTypeIsNull(value); 306 } 307 static handle cast(MlirType t, return_value_policy, handle) { 308 py::object capsule = 309 py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t)); 310 return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 311 .attr("Type") 312 .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) 313 .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() 314 .release(); 315 } 316 }; 317 318 } // namespace detail 319 } // namespace pybind11 320 321 namespace mlir { 322 namespace python { 323 namespace adaptors { 324 325 /// Provides a facility like py::class_ for defining a new class in a scope, 326 /// but this allows extension of an arbitrary Python class, defining methods 327 /// on it is a similar way. Classes defined in this way are very similar to 328 /// if defined in Python in the usual way but use Pybind11 machinery to do 329 /// it. These are not "real" Pybind11 classes but pure Python classes with no 330 /// relation to a concrete C++ class. 331 /// 332 /// Derived from a discussion upstream: 333 /// https://github.com/pybind/pybind11/issues/1193 334 /// (plus a fair amount of extra curricular poking) 335 /// TODO: If this proves useful, see about including it in pybind11. 336 class pure_subclass { 337 public: 338 pure_subclass(py::handle scope, const char *derivedClassName, 339 const py::object &superClass) { 340 py::object pyType = 341 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 342 py::object metaclass = pyType(superClass); 343 py::dict attributes; 344 345 thisClass = 346 metaclass(derivedClassName, py::make_tuple(superClass), attributes); 347 scope.attr(derivedClassName) = thisClass; 348 } 349 350 template <typename Func, typename... Extra> 351 pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { 352 py::cpp_function cf( 353 std::forward<Func>(f), py::name(name), py::is_method(thisClass), 354 py::sibling(py::getattr(thisClass, name, py::none())), extra...); 355 thisClass.attr(cf.name()) = cf; 356 return *this; 357 } 358 359 template <typename Func, typename... Extra> 360 pure_subclass &def_property_readonly(const char *name, Func &&f, 361 const Extra &...extra) { 362 py::cpp_function cf( 363 std::forward<Func>(f), py::name(name), py::is_method(thisClass), 364 py::sibling(py::getattr(thisClass, name, py::none())), extra...); 365 auto builtinProperty = 366 py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type); 367 thisClass.attr(name) = builtinProperty(cf); 368 return *this; 369 } 370 371 template <typename Func, typename... Extra> 372 pure_subclass &def_staticmethod(const char *name, Func &&f, 373 const Extra &...extra) { 374 static_assert(!std::is_member_function_pointer<Func>::value, 375 "def_staticmethod(...) called with a non-static member " 376 "function pointer"); 377 py::cpp_function cf(std::forward<Func>(f), py::name(name), 378 py::scope(thisClass), extra...); 379 thisClass.attr(cf.name()) = py::staticmethod(cf); 380 return *this; 381 } 382 383 template <typename Func, typename... Extra> 384 pure_subclass &def_classmethod(const char *name, Func &&f, 385 const Extra &...extra) { 386 static_assert(!std::is_member_function_pointer<Func>::value, 387 "def_classmethod(...) called with a non-static member " 388 "function pointer"); 389 py::cpp_function cf(std::forward<Func>(f), py::name(name), 390 py::scope(thisClass), extra...); 391 thisClass.attr(cf.name()) = 392 py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr())); 393 return *this; 394 } 395 396 py::object get_class() const { return thisClass; } 397 398 protected: 399 py::object superClass; 400 py::object thisClass; 401 }; 402 403 /// Creates a custom subclass of mlir.ir.Attribute, implementing a casting 404 /// constructor and type checking methods. 405 class mlir_attribute_subclass : public pure_subclass { 406 public: 407 using IsAFunctionTy = bool (*)(MlirAttribute); 408 using GetTypeIDFunctionTy = MlirTypeID (*)(); 409 410 /// Subclasses by looking up the super-class dynamically. 411 mlir_attribute_subclass(py::handle scope, const char *attrClassName, 412 IsAFunctionTy isaFunction, 413 GetTypeIDFunctionTy getTypeIDFunction = nullptr) 414 : mlir_attribute_subclass( 415 scope, attrClassName, isaFunction, 416 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 417 .attr("Attribute"), 418 getTypeIDFunction) {} 419 420 /// Subclasses with a provided mlir.ir.Attribute super-class. This must 421 /// be used if the subclass is being defined in the same extension module 422 /// as the mlir.ir class (otherwise, it will trigger a recursive 423 /// initialization). 424 mlir_attribute_subclass(py::handle scope, const char *typeClassName, 425 IsAFunctionTy isaFunction, const py::object &superCls, 426 GetTypeIDFunctionTy getTypeIDFunction = nullptr) 427 : pure_subclass(scope, typeClassName, superCls) { 428 // Casting constructor. Note that it hard, if not impossible, to properly 429 // call chain to parent `__init__` in pybind11 due to its special handling 430 // for init functions that don't have a fully constructed self-reference, 431 // which makes it impossible to forward it to `__init__` of a superclass. 432 // Instead, provide a custom `__new__` and call that of a superclass, which 433 // eventually calls `__init__` of the superclass. Since attribute subclasses 434 // have no additional members, we can just return the instance thus created 435 // without amending it. 436 std::string captureTypeName( 437 typeClassName); // As string in case if typeClassName is not static. 438 py::cpp_function newCf( 439 [superCls, isaFunction, captureTypeName](py::object cls, 440 py::object otherAttribute) { 441 MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute); 442 if (!isaFunction(rawAttribute)) { 443 auto origRepr = py::repr(otherAttribute).cast<std::string>(); 444 throw std::invalid_argument( 445 (llvm::Twine("Cannot cast attribute to ") + captureTypeName + 446 " (from " + origRepr + ")") 447 .str()); 448 } 449 py::object self = superCls.attr("__new__")(cls, otherAttribute); 450 return self; 451 }, 452 py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); 453 thisClass.attr("__new__") = newCf; 454 455 // 'isinstance' method. 456 def_staticmethod( 457 "isinstance", 458 [isaFunction](MlirAttribute other) { return isaFunction(other); }, 459 py::arg("other_attribute")); 460 def("__repr__", [superCls, captureTypeName](py::object self) { 461 return py::repr(superCls(self)) 462 .attr("replace")(superCls.attr("__name__"), captureTypeName); 463 }); 464 if (getTypeIDFunction) { 465 def_staticmethod("get_static_typeid", 466 [getTypeIDFunction]() { return getTypeIDFunction(); }); 467 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 468 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 469 getTypeIDFunction())(pybind11::cpp_function( 470 [thisClass = thisClass](const py::object &mlirAttribute) { 471 return thisClass(mlirAttribute); 472 })); 473 } 474 } 475 }; 476 477 /// Creates a custom subclass of mlir.ir.Type, implementing a casting 478 /// constructor and type checking methods. 479 class mlir_type_subclass : public pure_subclass { 480 public: 481 using IsAFunctionTy = bool (*)(MlirType); 482 using GetTypeIDFunctionTy = MlirTypeID (*)(); 483 484 /// Subclasses by looking up the super-class dynamically. 485 mlir_type_subclass(py::handle scope, const char *typeClassName, 486 IsAFunctionTy isaFunction, 487 GetTypeIDFunctionTy getTypeIDFunction = nullptr) 488 : mlir_type_subclass( 489 scope, typeClassName, isaFunction, 490 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"), 491 getTypeIDFunction) {} 492 493 /// Subclasses with a provided mlir.ir.Type super-class. This must 494 /// be used if the subclass is being defined in the same extension module 495 /// as the mlir.ir class (otherwise, it will trigger a recursive 496 /// initialization). 497 mlir_type_subclass(py::handle scope, const char *typeClassName, 498 IsAFunctionTy isaFunction, const py::object &superCls, 499 GetTypeIDFunctionTy getTypeIDFunction = nullptr) 500 : pure_subclass(scope, typeClassName, superCls) { 501 // Casting constructor. Note that it hard, if not impossible, to properly 502 // call chain to parent `__init__` in pybind11 due to its special handling 503 // for init functions that don't have a fully constructed self-reference, 504 // which makes it impossible to forward it to `__init__` of a superclass. 505 // Instead, provide a custom `__new__` and call that of a superclass, which 506 // eventually calls `__init__` of the superclass. Since attribute subclasses 507 // have no additional members, we can just return the instance thus created 508 // without amending it. 509 std::string captureTypeName( 510 typeClassName); // As string in case if typeClassName is not static. 511 py::cpp_function newCf( 512 [superCls, isaFunction, captureTypeName](py::object cls, 513 py::object otherType) { 514 MlirType rawType = py::cast<MlirType>(otherType); 515 if (!isaFunction(rawType)) { 516 auto origRepr = py::repr(otherType).cast<std::string>(); 517 throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + 518 captureTypeName + " (from " + 519 origRepr + ")") 520 .str()); 521 } 522 py::object self = superCls.attr("__new__")(cls, otherType); 523 return self; 524 }, 525 py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); 526 thisClass.attr("__new__") = newCf; 527 528 // 'isinstance' method. 529 def_staticmethod( 530 "isinstance", 531 [isaFunction](MlirType other) { return isaFunction(other); }, 532 py::arg("other_type")); 533 def("__repr__", [superCls, captureTypeName](py::object self) { 534 return py::repr(superCls(self)) 535 .attr("replace")(superCls.attr("__name__"), captureTypeName); 536 }); 537 if (getTypeIDFunction) { 538 // 'get_static_typeid' method. 539 // This is modeled as a static method instead of a static property because 540 // `def_property_readonly_static` is not available in `pure_subclass` and 541 // we do not want to introduce the complexity that pybind uses to 542 // implement it. 543 def_staticmethod("get_static_typeid", 544 [getTypeIDFunction]() { return getTypeIDFunction(); }); 545 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 546 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 547 getTypeIDFunction())(pybind11::cpp_function( 548 [thisClass = thisClass](const py::object &mlirType) { 549 return thisClass(mlirType); 550 })); 551 } 552 } 553 }; 554 555 /// Creates a custom subclass of mlir.ir.Value, implementing a casting 556 /// constructor and type checking methods. 557 class mlir_value_subclass : public pure_subclass { 558 public: 559 using IsAFunctionTy = bool (*)(MlirValue); 560 561 /// Subclasses by looking up the super-class dynamically. 562 mlir_value_subclass(py::handle scope, const char *valueClassName, 563 IsAFunctionTy isaFunction) 564 : mlir_value_subclass( 565 scope, valueClassName, isaFunction, 566 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) { 567 } 568 569 /// Subclasses with a provided mlir.ir.Value super-class. This must 570 /// be used if the subclass is being defined in the same extension module 571 /// as the mlir.ir class (otherwise, it will trigger a recursive 572 /// initialization). 573 mlir_value_subclass(py::handle scope, const char *valueClassName, 574 IsAFunctionTy isaFunction, const py::object &superCls) 575 : pure_subclass(scope, valueClassName, superCls) { 576 // Casting constructor. Note that it hard, if not impossible, to properly 577 // call chain to parent `__init__` in pybind11 due to its special handling 578 // for init functions that don't have a fully constructed self-reference, 579 // which makes it impossible to forward it to `__init__` of a superclass. 580 // Instead, provide a custom `__new__` and call that of a superclass, which 581 // eventually calls `__init__` of the superclass. Since attribute subclasses 582 // have no additional members, we can just return the instance thus created 583 // without amending it. 584 std::string captureValueName( 585 valueClassName); // As string in case if valueClassName is not static. 586 py::cpp_function newCf( 587 [superCls, isaFunction, captureValueName](py::object cls, 588 py::object otherValue) { 589 MlirValue rawValue = py::cast<MlirValue>(otherValue); 590 if (!isaFunction(rawValue)) { 591 auto origRepr = py::repr(otherValue).cast<std::string>(); 592 throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + 593 captureValueName + " (from " + 594 origRepr + ")") 595 .str()); 596 } 597 py::object self = superCls.attr("__new__")(cls, otherValue); 598 return self; 599 }, 600 py::name("__new__"), py::arg("cls"), py::arg("cast_from_value")); 601 thisClass.attr("__new__") = newCf; 602 603 // 'isinstance' method. 604 def_staticmethod( 605 "isinstance", 606 [isaFunction](MlirValue other) { return isaFunction(other); }, 607 py::arg("other_value")); 608 } 609 }; 610 611 } // namespace adaptors 612 613 } // namespace python 614 } // namespace mlir 615 616 #endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H 617