1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Globals.h" 14 #include "IRModule.h" 15 #include "Pass.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 // ----------------------------------------------------------------------------- 22 // Module initialization. 23 // ----------------------------------------------------------------------------- 24 25 PYBIND11_MODULE(_mlir, m) { 26 m.doc() = "MLIR Python Native Extension"; 27 28 py::class_<PyGlobals>(m, "_Globals", py::module_local()) 29 .def_property("dialect_search_modules", 30 &PyGlobals::getDialectSearchPrefixes, 31 &PyGlobals::setDialectSearchPrefixes) 32 .def( 33 "append_dialect_search_prefix", 34 [](PyGlobals &self, std::string moduleName) { 35 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 36 self.clearImportCache(); 37 }, 38 py::arg("module_name")) 39 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 40 py::arg("dialect_namespace"), py::arg("dialect_class"), 41 "Testing hook for directly registering a dialect") 42 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 43 py::arg("operation_name"), py::arg("operation_class"), 44 "Testing hook for directly registering an operation"); 45 46 // Aside from making the globals accessible to python, having python manage 47 // it is necessary to make sure it is destroyed (and releases its python 48 // resources) properly. 49 m.attr("globals") = 50 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 51 52 // Registration decorators. 53 m.def( 54 "register_dialect", 55 [](py::object pyClass) { 56 std::string dialectNamespace = 57 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 58 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 59 return pyClass; 60 }, 61 py::arg("dialect_class"), 62 "Class decorator for registering a custom Dialect wrapper"); 63 m.def( 64 "register_operation", 65 [](py::object dialectClass) -> py::cpp_function { 66 return py::cpp_function( 67 [dialectClass](py::object opClass) -> py::object { 68 std::string operationName = 69 opClass.attr("OPERATION_NAME").cast<std::string>(); 70 PyGlobals::get().registerOperationImpl(operationName, opClass); 71 72 // Dict-stuff the new opClass by name onto the dialect class. 73 py::object opClassName = opClass.attr("__name__"); 74 dialectClass.attr(opClassName) = opClass; 75 return opClass; 76 }); 77 }, 78 py::arg("dialect_class"), 79 "Produce a class decorator for registering an Operation class as part of " 80 "a dialect"); 81 82 // Define and populate IR submodule. 83 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 84 populateIRCore(irModule); 85 populateIRAffine(irModule); 86 populateIRAttributes(irModule); 87 populateIRInterfaces(irModule); 88 populateIRTypes(irModule); 89 90 // Define and populate PassManager submodule. 91 auto passModule = 92 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 93 populatePassManagerSubmodule(passModule); 94 } 95