1 //===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This is the nanobind edition of the PythonTest dialect module. 9 //===----------------------------------------------------------------------===// 10 11 #include "PythonTestCAPI.h" 12 #include "mlir-c/BuiltinAttributes.h" 13 #include "mlir-c/BuiltinTypes.h" 14 #include "mlir-c/IR.h" 15 #include "mlir/Bindings/Python/Nanobind.h" 16 #include "mlir/Bindings/Python/NanobindAdaptors.h" 17 18 namespace nb = nanobind; 19 using namespace mlir::python::nanobind_adaptors; 20 21 static bool mlirTypeIsARankedIntegerTensor(MlirType t) { 22 return mlirTypeIsARankedTensor(t) && 23 mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); 24 } 25 26 NB_MODULE(_mlirPythonTestNanobind, m) { 27 m.def( 28 "register_python_test_dialect", 29 [](MlirContext context, bool load) { 30 MlirDialectHandle pythonTestDialect = 31 mlirGetDialectHandle__python_test__(); 32 mlirDialectHandleRegisterDialect(pythonTestDialect, context); 33 if (load) { 34 mlirDialectHandleLoadDialect(pythonTestDialect, context); 35 } 36 }, 37 nb::arg("context"), nb::arg("load") = true); 38 39 m.def( 40 "register_dialect", 41 [](MlirDialectRegistry registry) { 42 MlirDialectHandle pythonTestDialect = 43 mlirGetDialectHandle__python_test__(); 44 mlirDialectHandleInsertDialect(pythonTestDialect, registry); 45 }, 46 nb::arg("registry")); 47 48 mlir_attribute_subclass(m, "TestAttr", 49 mlirAttributeIsAPythonTestTestAttribute, 50 mlirPythonTestTestAttributeGetTypeID) 51 .def_classmethod( 52 "get", 53 [](const nb::object &cls, MlirContext ctx) { 54 return cls(mlirPythonTestTestAttributeGet(ctx)); 55 }, 56 nb::arg("cls"), nb::arg("context").none() = nb::none()); 57 58 mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, 59 mlirPythonTestTestTypeGetTypeID) 60 .def_classmethod( 61 "get", 62 [](const nb::object &cls, MlirContext ctx) { 63 return cls(mlirPythonTestTestTypeGet(ctx)); 64 }, 65 nb::arg("cls"), nb::arg("context").none() = nb::none()); 66 67 auto typeCls = 68 mlir_type_subclass(m, "TestIntegerRankedTensorType", 69 mlirTypeIsARankedIntegerTensor, 70 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 71 .attr("RankedTensorType")) 72 .def_classmethod( 73 "get", 74 [](const nb::object &cls, std::vector<int64_t> shape, 75 unsigned width, MlirContext ctx) { 76 MlirAttribute encoding = mlirAttributeGetNull(); 77 return cls(mlirRankedTensorTypeGet( 78 shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), 79 encoding)); 80 }, 81 nb::arg("cls"), nb::arg("shape"), nb::arg("width"), 82 nb::arg("context").none() = nb::none()); 83 84 assert(nb::hasattr(typeCls.get_class(), "static_typeid") && 85 "TestIntegerRankedTensorType has no static_typeid"); 86 87 MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); 88 89 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 90 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 91 mlirRankedTensorTypeID, nb::arg("replace") = true)( 92 nanobind::cpp_function([typeCls](const nb::object &mlirType) { 93 return typeCls.get_class()(mlirType); 94 })); 95 96 auto valueCls = mlir_value_subclass(m, "TestTensorValue", 97 mlirTypeIsAPythonTestTestTensorValue) 98 .def("is_null", [](MlirValue &self) { 99 return mlirValueIsNull(self); 100 }); 101 102 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 103 .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( 104 mlirRankedTensorTypeID)( 105 nanobind::cpp_function([valueCls](const nb::object &valueObj) { 106 nb::object capsule = mlirApiObjectToCapsule(valueObj); 107 MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); 108 MlirType t = mlirValueGetType(v); 109 // This is hyper-specific in order to exercise/test registering a 110 // value caster from cpp (but only for a single test case; see 111 // testTensorValue python_test.py). 112 if (mlirShapedTypeHasStaticShape(t) && 113 mlirShapedTypeGetDimSize(t, 0) == 1 && 114 mlirShapedTypeGetDimSize(t, 1) == 2 && 115 mlirShapedTypeGetDimSize(t, 2) == 3) 116 return valueCls.get_class()(valueObj); 117 return valueObj; 118 })); 119 } 120