xref: /llvm-project/mlir/test/python/lib/PythonTestModuleNanobind.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1392622d0SMaksim Levental //===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===//
2392622d0SMaksim Levental //
3392622d0SMaksim Levental // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4392622d0SMaksim Levental // See https://llvm.org/LICENSE.txt for license information.
5392622d0SMaksim Levental // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6392622d0SMaksim Levental //
7392622d0SMaksim Levental //===----------------------------------------------------------------------===//
8392622d0SMaksim Levental // This is the nanobind edition of the PythonTest dialect module.
9392622d0SMaksim Levental //===----------------------------------------------------------------------===//
10392622d0SMaksim Levental 
11392622d0SMaksim Levental #include "PythonTestCAPI.h"
12392622d0SMaksim Levental #include "mlir-c/BuiltinAttributes.h"
13392622d0SMaksim Levental #include "mlir-c/BuiltinTypes.h"
14392622d0SMaksim Levental #include "mlir-c/IR.h"
15*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
16392622d0SMaksim Levental #include "mlir/Bindings/Python/NanobindAdaptors.h"
17392622d0SMaksim Levental 
18392622d0SMaksim Levental namespace nb = nanobind;
19392622d0SMaksim Levental using namespace mlir::python::nanobind_adaptors;
20392622d0SMaksim Levental 
21392622d0SMaksim Levental static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
22392622d0SMaksim Levental   return mlirTypeIsARankedTensor(t) &&
23392622d0SMaksim Levental          mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
24392622d0SMaksim Levental }
25392622d0SMaksim Levental 
26392622d0SMaksim Levental NB_MODULE(_mlirPythonTestNanobind, m) {
27392622d0SMaksim Levental   m.def(
28392622d0SMaksim Levental       "register_python_test_dialect",
29392622d0SMaksim Levental       [](MlirContext context, bool load) {
30392622d0SMaksim Levental         MlirDialectHandle pythonTestDialect =
31392622d0SMaksim Levental             mlirGetDialectHandle__python_test__();
32392622d0SMaksim Levental         mlirDialectHandleRegisterDialect(pythonTestDialect, context);
33392622d0SMaksim Levental         if (load) {
34392622d0SMaksim Levental           mlirDialectHandleLoadDialect(pythonTestDialect, context);
35392622d0SMaksim Levental         }
36392622d0SMaksim Levental       },
37392622d0SMaksim Levental       nb::arg("context"), nb::arg("load") = true);
38392622d0SMaksim Levental 
39392622d0SMaksim Levental   m.def(
40392622d0SMaksim Levental       "register_dialect",
41392622d0SMaksim Levental       [](MlirDialectRegistry registry) {
42392622d0SMaksim Levental         MlirDialectHandle pythonTestDialect =
43392622d0SMaksim Levental             mlirGetDialectHandle__python_test__();
44392622d0SMaksim Levental         mlirDialectHandleInsertDialect(pythonTestDialect, registry);
45392622d0SMaksim Levental       },
46392622d0SMaksim Levental       nb::arg("registry"));
47392622d0SMaksim Levental 
48392622d0SMaksim Levental   mlir_attribute_subclass(m, "TestAttr",
49392622d0SMaksim Levental                           mlirAttributeIsAPythonTestTestAttribute,
50392622d0SMaksim Levental                           mlirPythonTestTestAttributeGetTypeID)
51392622d0SMaksim Levental       .def_classmethod(
52392622d0SMaksim Levental           "get",
53392622d0SMaksim Levental           [](const nb::object &cls, MlirContext ctx) {
54392622d0SMaksim Levental             return cls(mlirPythonTestTestAttributeGet(ctx));
55392622d0SMaksim Levental           },
56392622d0SMaksim Levental           nb::arg("cls"), nb::arg("context").none() = nb::none());
57392622d0SMaksim Levental 
58392622d0SMaksim Levental   mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
59392622d0SMaksim Levental                      mlirPythonTestTestTypeGetTypeID)
60392622d0SMaksim Levental       .def_classmethod(
61392622d0SMaksim Levental           "get",
62392622d0SMaksim Levental           [](const nb::object &cls, MlirContext ctx) {
63392622d0SMaksim Levental             return cls(mlirPythonTestTestTypeGet(ctx));
64392622d0SMaksim Levental           },
65392622d0SMaksim Levental           nb::arg("cls"), nb::arg("context").none() = nb::none());
66392622d0SMaksim Levental 
67392622d0SMaksim Levental   auto typeCls =
68392622d0SMaksim Levental       mlir_type_subclass(m, "TestIntegerRankedTensorType",
69392622d0SMaksim Levental                          mlirTypeIsARankedIntegerTensor,
70392622d0SMaksim Levental                          nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
71392622d0SMaksim Levental                              .attr("RankedTensorType"))
72392622d0SMaksim Levental           .def_classmethod(
73392622d0SMaksim Levental               "get",
74392622d0SMaksim Levental               [](const nb::object &cls, std::vector<int64_t> shape,
75392622d0SMaksim Levental                  unsigned width, MlirContext ctx) {
76392622d0SMaksim Levental                 MlirAttribute encoding = mlirAttributeGetNull();
77392622d0SMaksim Levental                 return cls(mlirRankedTensorTypeGet(
78392622d0SMaksim Levental                     shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
79392622d0SMaksim Levental                     encoding));
80392622d0SMaksim Levental               },
81392622d0SMaksim Levental               nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
82392622d0SMaksim Levental               nb::arg("context").none() = nb::none());
83392622d0SMaksim Levental 
84392622d0SMaksim Levental   assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
85392622d0SMaksim Levental          "TestIntegerRankedTensorType has no static_typeid");
86392622d0SMaksim Levental 
87392622d0SMaksim Levental   MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
88392622d0SMaksim Levental 
89392622d0SMaksim Levental   nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
90392622d0SMaksim Levental       .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
91392622d0SMaksim Levental           mlirRankedTensorTypeID, nb::arg("replace") = true)(
92392622d0SMaksim Levental           nanobind::cpp_function([typeCls](const nb::object &mlirType) {
93392622d0SMaksim Levental             return typeCls.get_class()(mlirType);
94392622d0SMaksim Levental           }));
95392622d0SMaksim Levental 
96392622d0SMaksim Levental   auto valueCls = mlir_value_subclass(m, "TestTensorValue",
97392622d0SMaksim Levental                                       mlirTypeIsAPythonTestTestTensorValue)
98392622d0SMaksim Levental                       .def("is_null", [](MlirValue &self) {
99392622d0SMaksim Levental                         return mlirValueIsNull(self);
100392622d0SMaksim Levental                       });
101392622d0SMaksim Levental 
102392622d0SMaksim Levental   nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
103392622d0SMaksim Levental       .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
104392622d0SMaksim Levental           mlirRankedTensorTypeID)(
105392622d0SMaksim Levental           nanobind::cpp_function([valueCls](const nb::object &valueObj) {
106392622d0SMaksim Levental             nb::object capsule = mlirApiObjectToCapsule(valueObj);
107392622d0SMaksim Levental             MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
108392622d0SMaksim Levental             MlirType t = mlirValueGetType(v);
109392622d0SMaksim Levental             // This is hyper-specific in order to exercise/test registering a
110392622d0SMaksim Levental             // value caster from cpp (but only for a single test case; see
111392622d0SMaksim Levental             // testTensorValue python_test.py).
112392622d0SMaksim Levental             if (mlirShapedTypeHasStaticShape(t) &&
113392622d0SMaksim Levental                 mlirShapedTypeGetDimSize(t, 0) == 1 &&
114392622d0SMaksim Levental                 mlirShapedTypeGetDimSize(t, 1) == 2 &&
115392622d0SMaksim Levental                 mlirShapedTypeGetDimSize(t, 2) == 3)
116392622d0SMaksim Levental               return valueCls.get_class()(valueObj);
117392622d0SMaksim Levental             return valueObj;
118392622d0SMaksim Levental           }));
119392622d0SMaksim Levental }
120