1 //===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This is the pybind11 edition of the PythonTest dialect module. 9 //===----------------------------------------------------------------------===// 10 11 #include "PythonTestCAPI.h" 12 #include "mlir-c/BuiltinAttributes.h" 13 #include "mlir-c/BuiltinTypes.h" 14 #include "mlir-c/IR.h" 15 #include "mlir/Bindings/Python/PybindAdaptors.h" 16 17 namespace py = pybind11; 18 using namespace mlir::python::adaptors; 19 using namespace pybind11::literals; 20 21 static bool mlirTypeIsARankedIntegerTensor(MlirType t) { 22 return mlirTypeIsARankedTensor(t) && 23 mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); 24 } 25 26 PYBIND11_MODULE(_mlirPythonTestPybind11, m) { 27 m.def( 28 "register_python_test_dialect", 29 [](MlirContext context, bool load) { 30 MlirDialectHandle pythonTestDialect = 31 mlirGetDialectHandle__python_test__(); 32 mlirDialectHandleRegisterDialect(pythonTestDialect, context); 33 if (load) { 34 mlirDialectHandleLoadDialect(pythonTestDialect, context); 35 } 36 }, 37 py::arg("context"), py::arg("load") = true); 38 39 m.def( 40 "register_dialect", 41 [](MlirDialectRegistry registry) { 42 MlirDialectHandle pythonTestDialect = 43 mlirGetDialectHandle__python_test__(); 44 mlirDialectHandleInsertDialect(pythonTestDialect, registry); 45 }, 46 py::arg("registry")); 47 48 mlir_attribute_subclass(m, "TestAttr", 49 mlirAttributeIsAPythonTestTestAttribute, 50 mlirPythonTestTestAttributeGetTypeID) 51 .def_classmethod( 52 "get", 53 [](const py::object &cls, MlirContext ctx) { 54 return cls(mlirPythonTestTestAttributeGet(ctx)); 55 }, 56 py::arg("cls"), py::arg("context") = py::none()); 57 58 mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, 59 mlirPythonTestTestTypeGetTypeID) 60 .def_classmethod( 61 "get", 62 [](const py::object &cls, MlirContext ctx) { 63 return cls(mlirPythonTestTestTypeGet(ctx)); 64 }, 65 py::arg("cls"), py::arg("context") = py::none()); 66 67 auto typeCls = 68 mlir_type_subclass(m, "TestIntegerRankedTensorType", 69 mlirTypeIsARankedIntegerTensor, 70 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 71 .attr("RankedTensorType")) 72 .def_classmethod( 73 "get", 74 [](const py::object &cls, std::vector<int64_t> shape, 75 unsigned width, MlirContext ctx) { 76 MlirAttribute encoding = mlirAttributeGetNull(); 77 return cls(mlirRankedTensorTypeGet( 78 shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), 79 encoding)); 80 }, 81 "cls"_a, "shape"_a, "width"_a, "context"_a = py::none()); 82 83 assert(py::hasattr(typeCls.get_class(), "static_typeid") && 84 "TestIntegerRankedTensorType has no static_typeid"); 85 86 MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); 87 88 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 89 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID, 90 "replace"_a = true)( 91 pybind11::cpp_function([typeCls](const py::object &mlirType) { 92 return typeCls.get_class()(mlirType); 93 })); 94 95 auto valueCls = mlir_value_subclass(m, "TestTensorValue", 96 mlirTypeIsAPythonTestTestTensorValue) 97 .def("is_null", [](MlirValue &self) { 98 return mlirValueIsNull(self); 99 }); 100 101 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 102 .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( 103 mlirRankedTensorTypeID)( 104 pybind11::cpp_function([valueCls](const py::object &valueObj) { 105 py::object capsule = mlirApiObjectToCapsule(valueObj); 106 MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); 107 MlirType t = mlirValueGetType(v); 108 // This is hyper-specific in order to exercise/test registering a 109 // value caster from cpp (but only for a single test case; see 110 // testTensorValue python_test.py). 111 if (mlirShapedTypeHasStaticShape(t) && 112 mlirShapedTypeGetDimSize(t, 0) == 1 && 113 mlirShapedTypeGetDimSize(t, 1) == 2 && 114 mlirShapedTypeGetDimSize(t, 2) == 3) 115 return valueCls.get_class()(valueObj); 116 return valueObj; 117 })); 118 } 119