1392622d0SMaksim Levental //===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===// 2392622d0SMaksim Levental // 3392622d0SMaksim Levental // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4392622d0SMaksim Levental // See https://llvm.org/LICENSE.txt for license information. 5392622d0SMaksim Levental // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6392622d0SMaksim Levental // 7392622d0SMaksim Levental //===----------------------------------------------------------------------===// 8392622d0SMaksim Levental // This is the nanobind edition of the PythonTest dialect module. 9392622d0SMaksim Levental //===----------------------------------------------------------------------===// 10392622d0SMaksim Levental 11392622d0SMaksim Levental #include "PythonTestCAPI.h" 12392622d0SMaksim Levental #include "mlir-c/BuiltinAttributes.h" 13392622d0SMaksim Levental #include "mlir-c/BuiltinTypes.h" 14392622d0SMaksim Levental #include "mlir-c/IR.h" 15*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 16392622d0SMaksim Levental #include "mlir/Bindings/Python/NanobindAdaptors.h" 17392622d0SMaksim Levental 18392622d0SMaksim Levental namespace nb = nanobind; 19392622d0SMaksim Levental using namespace mlir::python::nanobind_adaptors; 20392622d0SMaksim Levental 21392622d0SMaksim Levental static bool mlirTypeIsARankedIntegerTensor(MlirType t) { 22392622d0SMaksim Levental return mlirTypeIsARankedTensor(t) && 23392622d0SMaksim Levental mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); 24392622d0SMaksim Levental } 25392622d0SMaksim Levental 26392622d0SMaksim Levental NB_MODULE(_mlirPythonTestNanobind, m) { 27392622d0SMaksim Levental m.def( 28392622d0SMaksim Levental "register_python_test_dialect", 29392622d0SMaksim Levental [](MlirContext context, bool load) { 30392622d0SMaksim Levental MlirDialectHandle pythonTestDialect = 31392622d0SMaksim Levental mlirGetDialectHandle__python_test__(); 32392622d0SMaksim Levental mlirDialectHandleRegisterDialect(pythonTestDialect, context); 33392622d0SMaksim Levental if (load) { 34392622d0SMaksim Levental mlirDialectHandleLoadDialect(pythonTestDialect, context); 35392622d0SMaksim Levental } 36392622d0SMaksim Levental }, 37392622d0SMaksim Levental nb::arg("context"), nb::arg("load") = true); 38392622d0SMaksim Levental 39392622d0SMaksim Levental m.def( 40392622d0SMaksim Levental "register_dialect", 41392622d0SMaksim Levental [](MlirDialectRegistry registry) { 42392622d0SMaksim Levental MlirDialectHandle pythonTestDialect = 43392622d0SMaksim Levental mlirGetDialectHandle__python_test__(); 44392622d0SMaksim Levental mlirDialectHandleInsertDialect(pythonTestDialect, registry); 45392622d0SMaksim Levental }, 46392622d0SMaksim Levental nb::arg("registry")); 47392622d0SMaksim Levental 48392622d0SMaksim Levental mlir_attribute_subclass(m, "TestAttr", 49392622d0SMaksim Levental mlirAttributeIsAPythonTestTestAttribute, 50392622d0SMaksim Levental mlirPythonTestTestAttributeGetTypeID) 51392622d0SMaksim Levental .def_classmethod( 52392622d0SMaksim Levental "get", 53392622d0SMaksim Levental [](const nb::object &cls, MlirContext ctx) { 54392622d0SMaksim Levental return cls(mlirPythonTestTestAttributeGet(ctx)); 55392622d0SMaksim Levental }, 56392622d0SMaksim Levental nb::arg("cls"), nb::arg("context").none() = nb::none()); 57392622d0SMaksim Levental 58392622d0SMaksim Levental mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, 59392622d0SMaksim Levental mlirPythonTestTestTypeGetTypeID) 60392622d0SMaksim Levental .def_classmethod( 61392622d0SMaksim Levental "get", 62392622d0SMaksim Levental [](const nb::object &cls, MlirContext ctx) { 63392622d0SMaksim Levental return cls(mlirPythonTestTestTypeGet(ctx)); 64392622d0SMaksim Levental }, 65392622d0SMaksim Levental nb::arg("cls"), nb::arg("context").none() = nb::none()); 66392622d0SMaksim Levental 67392622d0SMaksim Levental auto typeCls = 68392622d0SMaksim Levental mlir_type_subclass(m, "TestIntegerRankedTensorType", 69392622d0SMaksim Levental mlirTypeIsARankedIntegerTensor, 70392622d0SMaksim Levental nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 71392622d0SMaksim Levental .attr("RankedTensorType")) 72392622d0SMaksim Levental .def_classmethod( 73392622d0SMaksim Levental "get", 74392622d0SMaksim Levental [](const nb::object &cls, std::vector<int64_t> shape, 75392622d0SMaksim Levental unsigned width, MlirContext ctx) { 76392622d0SMaksim Levental MlirAttribute encoding = mlirAttributeGetNull(); 77392622d0SMaksim Levental return cls(mlirRankedTensorTypeGet( 78392622d0SMaksim Levental shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), 79392622d0SMaksim Levental encoding)); 80392622d0SMaksim Levental }, 81392622d0SMaksim Levental nb::arg("cls"), nb::arg("shape"), nb::arg("width"), 82392622d0SMaksim Levental nb::arg("context").none() = nb::none()); 83392622d0SMaksim Levental 84392622d0SMaksim Levental assert(nb::hasattr(typeCls.get_class(), "static_typeid") && 85392622d0SMaksim Levental "TestIntegerRankedTensorType has no static_typeid"); 86392622d0SMaksim Levental 87392622d0SMaksim Levental MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); 88392622d0SMaksim Levental 89392622d0SMaksim Levental nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 90392622d0SMaksim Levental .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( 91392622d0SMaksim Levental mlirRankedTensorTypeID, nb::arg("replace") = true)( 92392622d0SMaksim Levental nanobind::cpp_function([typeCls](const nb::object &mlirType) { 93392622d0SMaksim Levental return typeCls.get_class()(mlirType); 94392622d0SMaksim Levental })); 95392622d0SMaksim Levental 96392622d0SMaksim Levental auto valueCls = mlir_value_subclass(m, "TestTensorValue", 97392622d0SMaksim Levental mlirTypeIsAPythonTestTestTensorValue) 98392622d0SMaksim Levental .def("is_null", [](MlirValue &self) { 99392622d0SMaksim Levental return mlirValueIsNull(self); 100392622d0SMaksim Levental }); 101392622d0SMaksim Levental 102392622d0SMaksim Levental nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 103392622d0SMaksim Levental .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( 104392622d0SMaksim Levental mlirRankedTensorTypeID)( 105392622d0SMaksim Levental nanobind::cpp_function([valueCls](const nb::object &valueObj) { 106392622d0SMaksim Levental nb::object capsule = mlirApiObjectToCapsule(valueObj); 107392622d0SMaksim Levental MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); 108392622d0SMaksim Levental MlirType t = mlirValueGetType(v); 109392622d0SMaksim Levental // This is hyper-specific in order to exercise/test registering a 110392622d0SMaksim Levental // value caster from cpp (but only for a single test case; see 111392622d0SMaksim Levental // testTensorValue python_test.py). 112392622d0SMaksim Levental if (mlirShapedTypeHasStaticShape(t) && 113392622d0SMaksim Levental mlirShapedTypeGetDimSize(t, 0) == 1 && 114392622d0SMaksim Levental mlirShapedTypeGetDimSize(t, 1) == 2 && 115392622d0SMaksim Levental mlirShapedTypeGetDimSize(t, 2) == 3) 116392622d0SMaksim Levental return valueCls.get_class()(valueObj); 117392622d0SMaksim Levental return valueObj; 118392622d0SMaksim Levental })); 119392622d0SMaksim Levental } 120