1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PybindUtils.h" 10 11 #include "Globals.h" 12 #include "IRModule.h" 13 #include "Pass.h" 14 15 #include <tuple> 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace py::literals; 20 using namespace mlir::python; 21 22 // ----------------------------------------------------------------------------- 23 // Module initialization. 24 // ----------------------------------------------------------------------------- 25 26 PYBIND11_MODULE(_mlir, m) { 27 m.doc() = "MLIR Python Native Extension"; 28 29 py::class_<PyGlobals>(m, "_Globals", py::module_local()) 30 .def_property("dialect_search_modules", 31 &PyGlobals::getDialectSearchPrefixes, 32 &PyGlobals::setDialectSearchPrefixes) 33 .def( 34 "append_dialect_search_prefix", 35 [](PyGlobals &self, std::string moduleName) { 36 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 37 }, 38 "module_name"_a) 39 .def( 40 "_check_dialect_module_loaded", 41 [](PyGlobals &self, const std::string &dialectNamespace) { 42 return self.loadDialectModule(dialectNamespace); 43 }, 44 "dialect_namespace"_a) 45 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 46 "dialect_namespace"_a, "dialect_class"_a, 47 "Testing hook for directly registering a dialect") 48 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 49 "operation_name"_a, "operation_class"_a, "replace"_a = false, 50 "Testing hook for directly registering an operation"); 51 52 // Aside from making the globals accessible to python, having python manage 53 // it is necessary to make sure it is destroyed (and releases its python 54 // resources) properly. 55 m.attr("globals") = 56 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 57 58 // Registration decorators. 59 m.def( 60 "register_dialect", 61 [](py::object pyClass) { 62 std::string dialectNamespace = 63 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 64 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 65 return pyClass; 66 }, 67 "dialect_class"_a, 68 "Class decorator for registering a custom Dialect wrapper"); 69 m.def( 70 "register_operation", 71 [](const py::object &dialectClass, bool replace) -> py::cpp_function { 72 return py::cpp_function( 73 [dialectClass, replace](py::object opClass) -> py::object { 74 std::string operationName = 75 opClass.attr("OPERATION_NAME").cast<std::string>(); 76 PyGlobals::get().registerOperationImpl(operationName, opClass, 77 replace); 78 79 // Dict-stuff the new opClass by name onto the dialect class. 80 py::object opClassName = opClass.attr("__name__"); 81 dialectClass.attr(opClassName) = opClass; 82 return opClass; 83 }); 84 }, 85 "dialect_class"_a, "replace"_a = false, 86 "Produce a class decorator for registering an Operation class as part of " 87 "a dialect"); 88 m.def( 89 MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, 90 [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) { 91 PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster), 92 replace); 93 }, 94 "typeid"_a, "type_caster"_a, "replace"_a = false, 95 "Register a type caster for casting MLIR types to custom user types."); 96 97 // Define and populate IR submodule. 98 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 99 populateIRCore(irModule); 100 populateIRAffine(irModule); 101 populateIRAttributes(irModule); 102 populateIRInterfaces(irModule); 103 populateIRTypes(irModule); 104 105 // Define and populate PassManager submodule. 106 auto passModule = 107 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 108 populatePassManagerSubmodule(passModule); 109 } 110