1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PybindUtils.h" 10 11 #include "Globals.h" 12 #include "IRModule.h" 13 #include "Pass.h" 14 15 namespace py = pybind11; 16 using namespace mlir; 17 using namespace py::literals; 18 using namespace mlir::python; 19 20 // ----------------------------------------------------------------------------- 21 // Module initialization. 22 // ----------------------------------------------------------------------------- 23 24 PYBIND11_MODULE(_mlir, m) { 25 m.doc() = "MLIR Python Native Extension"; 26 27 py::class_<PyGlobals>(m, "_Globals", py::module_local()) 28 .def_property("dialect_search_modules", 29 &PyGlobals::getDialectSearchPrefixes, 30 &PyGlobals::setDialectSearchPrefixes) 31 .def( 32 "append_dialect_search_prefix", 33 [](PyGlobals &self, std::string moduleName) { 34 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 35 }, 36 "module_name"_a) 37 .def( 38 "_check_dialect_module_loaded", 39 [](PyGlobals &self, const std::string &dialectNamespace) { 40 return self.loadDialectModule(dialectNamespace); 41 }, 42 "dialect_namespace"_a) 43 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 44 "dialect_namespace"_a, "dialect_class"_a, 45 "Testing hook for directly registering a dialect") 46 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 47 "operation_name"_a, "operation_class"_a, py::kw_only(), 48 "replace"_a = false, 49 "Testing hook for directly registering an operation"); 50 51 // Aside from making the globals accessible to python, having python manage 52 // it is necessary to make sure it is destroyed (and releases its python 53 // resources) properly. 54 m.attr("globals") = 55 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 56 57 // Registration decorators. 58 m.def( 59 "register_dialect", 60 [](py::object pyClass) { 61 std::string dialectNamespace = 62 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 63 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 64 return pyClass; 65 }, 66 "dialect_class"_a, 67 "Class decorator for registering a custom Dialect wrapper"); 68 m.def( 69 "register_operation", 70 [](const py::object &dialectClass, bool replace) -> py::cpp_function { 71 return py::cpp_function( 72 [dialectClass, replace](py::object opClass) -> py::object { 73 std::string operationName = 74 opClass.attr("OPERATION_NAME").cast<std::string>(); 75 PyGlobals::get().registerOperationImpl(operationName, opClass, 76 replace); 77 78 // Dict-stuff the new opClass by name onto the dialect class. 79 py::object opClassName = opClass.attr("__name__"); 80 dialectClass.attr(opClassName) = opClass; 81 return opClass; 82 }); 83 }, 84 "dialect_class"_a, py::kw_only(), "replace"_a = false, 85 "Produce a class decorator for registering an Operation class as part of " 86 "a dialect"); 87 m.def( 88 MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, 89 [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { 90 return py::cpp_function([mlirTypeID, 91 replace](py::object typeCaster) -> py::object { 92 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); 93 return typeCaster; 94 }); 95 }, 96 "typeid"_a, py::kw_only(), "replace"_a = false, 97 "Register a type caster for casting MLIR types to custom user types."); 98 m.def( 99 MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, 100 [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { 101 return py::cpp_function( 102 [mlirTypeID, replace](py::object valueCaster) -> py::object { 103 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, 104 replace); 105 return valueCaster; 106 }); 107 }, 108 "typeid"_a, py::kw_only(), "replace"_a = false, 109 "Register a value caster for casting MLIR values to custom user values."); 110 111 // Define and populate IR submodule. 112 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 113 populateIRCore(irModule); 114 populateIRAffine(irModule); 115 populateIRAttributes(irModule); 116 populateIRInterfaces(irModule); 117 populateIRTypes(irModule); 118 119 // Define and populate PassManager submodule. 120 auto passModule = 121 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 122 populatePassManagerSubmodule(passModule); 123 } 124