1*392622d0SMaksim Levental //===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// 2*392622d0SMaksim Levental // 3*392622d0SMaksim Levental // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*392622d0SMaksim Levental // See https://llvm.org/LICENSE.txt for license information. 5*392622d0SMaksim Levental // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*392622d0SMaksim Levental // 7*392622d0SMaksim Levental //===----------------------------------------------------------------------===// 8*392622d0SMaksim Levental // This is the pybind11 edition of the PythonTest dialect module. 9*392622d0SMaksim Levental //===----------------------------------------------------------------------===// 10*392622d0SMaksim Levental 11*392622d0SMaksim Levental #include "PythonTestCAPI.h" 12*392622d0SMaksim Levental #include "mlir-c/BuiltinAttributes.h" 13*392622d0SMaksim Levental #include "mlir-c/BuiltinTypes.h" 14*392622d0SMaksim Levental #include "mlir-c/IR.h" 15*392622d0SMaksim Levental #include "mlir/Bindings/Python/PybindAdaptors.h" 16*392622d0SMaksim Levental 17*392622d0SMaksim Levental namespace py = pybind11; 18*392622d0SMaksim Levental using namespace mlir::python::adaptors; 19*392622d0SMaksim Levental using namespace pybind11::literals; 20*392622d0SMaksim Levental 21*392622d0SMaksim Levental static bool mlirTypeIsARankedIntegerTensor(MlirType t) { 22*392622d0SMaksim Levental return mlirTypeIsARankedTensor(t) && 23*392622d0SMaksim Levental mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); 24*392622d0SMaksim Levental } 25*392622d0SMaksim Levental 26*392622d0SMaksim Levental PYBIND11_MODULE(_mlirPythonTestPybind11, m) { 27*392622d0SMaksim Levental m.def( 28*392622d0SMaksim Levental "register_python_test_dialect", 29*392622d0SMaksim Levental [](MlirContext context, bool load) { 30*392622d0SMaksim Levental MlirDialectHandle pythonTestDialect = 31*392622d0SMaksim Levental mlirGetDialectHandle__python_test__(); 32*392622d0SMaksim Levental mlirDialectHandleRegisterDialect(pythonTestDialect, context); 33*392622d0SMaksim Levental if (load) { 34*392622d0SMaksim Levental mlirDialectHandleLoadDialect(pythonTestDialect, context); 35*392622d0SMaksim Levental } 36*392622d0SMaksim Levental }, 37*392622d0SMaksim Levental py::arg("context"), py::arg("load") = true); 38*392622d0SMaksim Levental 39*392622d0SMaksim Levental m.def( 40*392622d0SMaksim Levental "register_dialect", 41*392622d0SMaksim Levental [](MlirDialectRegistry registry) { 42*392622d0SMaksim Levental MlirDialectHandle pythonTestDialect = 43*392622d0SMaksim Levental mlirGetDialectHandle__python_test__(); 44*392622d0SMaksim Levental mlirDialectHandleInsertDialect(pythonTestDialect, registry); 45*392622d0SMaksim Levental }, 46*392622d0SMaksim Levental py::arg("registry")); 47*392622d0SMaksim Levental 48*392622d0SMaksim Levental mlir_attribute_subclass(m, "TestAttr", 49*392622d0SMaksim Levental mlirAttributeIsAPythonTestTestAttribute, 50*392622d0SMaksim Levental mlirPythonTestTestAttributeGetTypeID) 51*392622d0SMaksim Levental .def_classmethod( 52*392622d0SMaksim Levental "get", 53*392622d0SMaksim Levental [](const py::object &cls, MlirContext ctx) { 54*392622d0SMaksim Levental return cls(mlirPythonTestTestAttributeGet(ctx)); 55*392622d0SMaksim Levental }, 56*392622d0SMaksim Levental py::arg("cls"), py::arg("context") = py::none()); 57*392622d0SMaksim Levental 58*392622d0SMaksim Levental mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, 59*392622d0SMaksim Levental mlirPythonTestTestTypeGetTypeID) 60*392622d0SMaksim Levental .def_classmethod( 61*392622d0SMaksim Levental "get", 62*392622d0SMaksim Levental [](const py::object &cls, MlirContext ctx) { 63*392622d0SMaksim Levental return cls(mlirPythonTestTestTypeGet(ctx)); 64*392622d0SMaksim Levental }, 65*392622d0SMaksim Levental py::arg("cls"), py::arg("context") = py::none()); 66*392622d0SMaksim Levental 67*392622d0SMaksim Levental auto typeCls = 68*392622d0SMaksim Levental mlir_type_subclass(m, "TestIntegerRankedTensorType", 69*392622d0SMaksim Levental mlirTypeIsARankedIntegerTensor, 70*392622d0SMaksim Levental py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 71*392622d0SMaksim Levental .attr("RankedTensorType")) 72*392622d0SMaksim Levental .def_classmethod( 73*392622d0SMaksim Levental "get", 74*392622d0SMaksim Levental [](const py::object &cls, std::vector<int64_t> shape, 75*392622d0SMaksim Levental unsigned width, MlirContext ctx) { 76*392622d0SMaksim Levental MlirAttribute encoding = mlirAttributeGetNull(); 77*392622d0SMaksim Levental return cls(mlirRankedTensorTypeGet( 78*392622d0SMaksim Levental shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), 79*392622d0SMaksim Levental encoding)); 80*392622d0SMaksim Levental }, 81*392622d0SMaksim Levental "cls"_a, "shape"_a, "width"_a, "context"_a = py::none()); 82*392622d0SMaksim Levental 83*392622d0SMaksim Levental assert(py::hasattr(typeCls.get_class(), "static_typeid") && 84*392622d0SMaksim Levental "TestIntegerRankedTensorType has no static_typeid"); 85*392622d0SMaksim Levental 86*392622d0SMaksim Levental MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); 87*392622d0SMaksim Levental 88*392622d0SMaksim Levental py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 89*392622d0SMaksim Levental .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID, 90*392622d0SMaksim Levental "replace"_a = true)( 91*392622d0SMaksim Levental pybind11::cpp_function([typeCls](const py::object &mlirType) { 92*392622d0SMaksim Levental return typeCls.get_class()(mlirType); 93*392622d0SMaksim Levental })); 94*392622d0SMaksim Levental 95*392622d0SMaksim Levental auto valueCls = mlir_value_subclass(m, "TestTensorValue", 96*392622d0SMaksim Levental mlirTypeIsAPythonTestTestTensorValue) 97*392622d0SMaksim Levental .def("is_null", [](MlirValue &self) { 98*392622d0SMaksim Levental return mlirValueIsNull(self); 99*392622d0SMaksim Levental }); 100*392622d0SMaksim Levental 101*392622d0SMaksim Levental py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) 102*392622d0SMaksim Levental .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( 103*392622d0SMaksim Levental mlirRankedTensorTypeID)( 104*392622d0SMaksim Levental pybind11::cpp_function([valueCls](const py::object &valueObj) { 105*392622d0SMaksim Levental py::object capsule = mlirApiObjectToCapsule(valueObj); 106*392622d0SMaksim Levental MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); 107*392622d0SMaksim Levental MlirType t = mlirValueGetType(v); 108*392622d0SMaksim Levental // This is hyper-specific in order to exercise/test registering a 109*392622d0SMaksim Levental // value caster from cpp (but only for a single test case; see 110*392622d0SMaksim Levental // testTensorValue python_test.py). 111*392622d0SMaksim Levental if (mlirShapedTypeHasStaticShape(t) && 112*392622d0SMaksim Levental mlirShapedTypeGetDimSize(t, 0) == 1 && 113*392622d0SMaksim Levental mlirShapedTypeGetDimSize(t, 1) == 2 && 114*392622d0SMaksim Levental mlirShapedTypeGetDimSize(t, 2) == 3) 115*392622d0SMaksim Levental return valueCls.get_class()(valueObj); 116*392622d0SMaksim Levental return valueObj; 117*392622d0SMaksim Levental })); 118*392622d0SMaksim Levental } 119