xref: /llvm-project/mlir/test/python/lib/PythonTestModuleNanobind.cpp (revision 392622d0848b2d0da951d3a4da6fb390a83f812b)
1 //===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This is the nanobind edition of the PythonTest dialect module.
9 //===----------------------------------------------------------------------===//
10 
11 #include <nanobind/nanobind.h>
12 #include <nanobind/stl/vector.h>
13 
14 #include "PythonTestCAPI.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/IR.h"
18 #include "mlir/Bindings/Python/NanobindAdaptors.h"
19 
20 namespace nb = nanobind;
21 using namespace mlir::python::nanobind_adaptors;
22 
23 static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
24   return mlirTypeIsARankedTensor(t) &&
25          mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
26 }
27 
28 NB_MODULE(_mlirPythonTestNanobind, m) {
29   m.def(
30       "register_python_test_dialect",
31       [](MlirContext context, bool load) {
32         MlirDialectHandle pythonTestDialect =
33             mlirGetDialectHandle__python_test__();
34         mlirDialectHandleRegisterDialect(pythonTestDialect, context);
35         if (load) {
36           mlirDialectHandleLoadDialect(pythonTestDialect, context);
37         }
38       },
39       nb::arg("context"), nb::arg("load") = true);
40 
41   m.def(
42       "register_dialect",
43       [](MlirDialectRegistry registry) {
44         MlirDialectHandle pythonTestDialect =
45             mlirGetDialectHandle__python_test__();
46         mlirDialectHandleInsertDialect(pythonTestDialect, registry);
47       },
48       nb::arg("registry"));
49 
50   mlir_attribute_subclass(m, "TestAttr",
51                           mlirAttributeIsAPythonTestTestAttribute,
52                           mlirPythonTestTestAttributeGetTypeID)
53       .def_classmethod(
54           "get",
55           [](const nb::object &cls, MlirContext ctx) {
56             return cls(mlirPythonTestTestAttributeGet(ctx));
57           },
58           nb::arg("cls"), nb::arg("context").none() = nb::none());
59 
60   mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
61                      mlirPythonTestTestTypeGetTypeID)
62       .def_classmethod(
63           "get",
64           [](const nb::object &cls, MlirContext ctx) {
65             return cls(mlirPythonTestTestTypeGet(ctx));
66           },
67           nb::arg("cls"), nb::arg("context").none() = nb::none());
68 
69   auto typeCls =
70       mlir_type_subclass(m, "TestIntegerRankedTensorType",
71                          mlirTypeIsARankedIntegerTensor,
72                          nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
73                              .attr("RankedTensorType"))
74           .def_classmethod(
75               "get",
76               [](const nb::object &cls, std::vector<int64_t> shape,
77                  unsigned width, MlirContext ctx) {
78                 MlirAttribute encoding = mlirAttributeGetNull();
79                 return cls(mlirRankedTensorTypeGet(
80                     shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
81                     encoding));
82               },
83               nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
84               nb::arg("context").none() = nb::none());
85 
86   assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
87          "TestIntegerRankedTensorType has no static_typeid");
88 
89   MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
90 
91   nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
92       .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
93           mlirRankedTensorTypeID, nb::arg("replace") = true)(
94           nanobind::cpp_function([typeCls](const nb::object &mlirType) {
95             return typeCls.get_class()(mlirType);
96           }));
97 
98   auto valueCls = mlir_value_subclass(m, "TestTensorValue",
99                                       mlirTypeIsAPythonTestTestTensorValue)
100                       .def("is_null", [](MlirValue &self) {
101                         return mlirValueIsNull(self);
102                       });
103 
104   nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
105       .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
106           mlirRankedTensorTypeID)(
107           nanobind::cpp_function([valueCls](const nb::object &valueObj) {
108             nb::object capsule = mlirApiObjectToCapsule(valueObj);
109             MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
110             MlirType t = mlirValueGetType(v);
111             // This is hyper-specific in order to exercise/test registering a
112             // value caster from cpp (but only for a single test case; see
113             // testTensorValue python_test.py).
114             if (mlirShapedTypeHasStaticShape(t) &&
115                 mlirShapedTypeGetDimSize(t, 0) == 1 &&
116                 mlirShapedTypeGetDimSize(t, 1) == 2 &&
117                 mlirShapedTypeGetDimSize(t, 2) == 3)
118               return valueCls.get_class()(valueObj);
119             return valueObj;
120           }));
121 }
122