1 //===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This is the nanobind edition of the PythonTest dialect module. 9 //===----------------------------------------------------------------------===// 10 11 #include <nanobind/nanobind.h> 12 #include <nanobind/stl/vector.h> 13 14 #include "PythonTestCAPI.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/IR.h" 18 #include "mlir/Bindings/Python/NanobindAdaptors.h" 19 20 namespace nb = nanobind; 21 using namespace mlir::python::nanobind_adaptors; 22 23 static bool mlirTypeIsARankedIntegerTensor(MlirType t) { 24 return mlirTypeIsARankedTensor(t) && 25 mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); 26 } 27 28 NB_MODULE(_mlirPythonTestNanobind, m) { 29 m.def( 30 "register_python_test_dialect", 31 [](MlirContext context, bool load) { 32 MlirDialectHandle pythonTestDialect = 33 mlirGetDialectHandle__python_test__(); 34 mlirDialectHandleRegisterDialect(pythonTestDialect, context); 35 if (load) { 36 mlirDialectHandleLoadDialect(pythonTestDialect, context); 37 } 38 }, 39 nb::arg("context"), nb::arg("load") = true); 40 41 m.def( 42 "register_dialect", 43 [](MlirDialectRegistry registry) { 44 MlirDialectHandle pythonTestDialect = 45 mlirGetDialectHandle__python_test__(); 46 mlirDialectHandleInsertDialect(pythonTestDialect, registry); 47 }, 48 nb::arg("registry")); 49 50 mlir_attribute_subclass(m, "TestAttr", 51 mlirAttributeIsAPythonTestTestAttribute, 52 mlirPythonTestTestAttributeGetTypeID) 53 .def_classmethod( 54 "get", 55 [](const nb::object &cls, MlirContext ctx) { 56 return cls(mlirPythonTestTestAttributeGet(ctx)); 57 }, 58 nb::arg("cls"), nb::arg("context").none() = nb::none()); 59 60 mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, 61 mlirPythonTestTestTypeGetTypeID) 62 .def_classmethod( 63 "get", 64 [](const nb::object &cls, MlirContext ctx) { 65 return cls(mlirPythonTestTestTypeGet(ctx)); 66 }, 67 nb::arg("cls"), nb::arg("context").none() = nb::none()); 68 69 auto typeCls = 70 mlir_type_subclass(m, "TestIntegerRankedTensorType", 71 mlirTypeIsARankedIntegerTensor, 72 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 73 .attr("RankedTensorType")) 74 .def_classmethod( 75 "get", 76 [](const nb::object &cls, std::vector<int64_t> shape, 77 unsigned width, MlirContext ctx) { 78 MlirAttribute encoding = mlirAttributeGetNull(); 79 return cls(mlirRankedTensorTypeGet( 80 shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), 81 encoding)); 82 }, 83 nb::arg("cls"), nb::arg("shape"), nb::arg("width"), 84 nb::arg("context").none() = nb::none()); 85 86 assert(nb::hasattr(typeCls.get_class(), "static_typeid") && 87 "TestIntegerRankedTensorType has no static_typeid"); 88 89 MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); 90 91 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 92 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 93 mlirRankedTensorTypeID, nb::arg("replace") = true)( 94 nanobind::cpp_function([typeCls](const nb::object &mlirType) { 95 return typeCls.get_class()(mlirType); 96 })); 97 98 auto valueCls = mlir_value_subclass(m, "TestTensorValue", 99 mlirTypeIsAPythonTestTestTensorValue) 100 .def("is_null", [](MlirValue &self) { 101 return mlirValueIsNull(self); 102 }); 103 104 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 105 .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( 106 mlirRankedTensorTypeID)( 107 nanobind::cpp_function([valueCls](const nb::object &valueObj) { 108 nb::object capsule = mlirApiObjectToCapsule(valueObj); 109 MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); 110 MlirType t = mlirValueGetType(v); 111 // This is hyper-specific in order to exercise/test registering a 112 // value caster from cpp (but only for a single test case; see 113 // testTensorValue python_test.py). 114 if (mlirShapedTypeHasStaticShape(t) && 115 mlirShapedTypeGetDimSize(t, 0) == 1 && 116 mlirShapedTypeGetDimSize(t, 1) == 2 && 117 mlirShapedTypeGetDimSize(t, 2) == 3) 118 return valueCls.get_class()(valueObj); 119 return valueObj; 120 })); 121 } 122