1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Globals.h" 14 #include "IRModule.h" 15 #include "Pass.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 // ----------------------------------------------------------------------------- 22 // Module initialization. 23 // ----------------------------------------------------------------------------- 24 25 PYBIND11_MODULE(_mlir, m) { 26 m.doc() = "MLIR Python Native Extension"; 27 28 py::class_<PyGlobals>(m, "_Globals", py::module_local()) 29 .def_property("dialect_search_modules", 30 &PyGlobals::getDialectSearchPrefixes, 31 &PyGlobals::setDialectSearchPrefixes) 32 .def( 33 "append_dialect_search_prefix", 34 [](PyGlobals &self, std::string moduleName) { 35 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 36 self.clearImportCache(); 37 }, 38 py::arg("module_name")) 39 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 40 py::arg("dialect_namespace"), py::arg("dialect_class"), 41 "Testing hook for directly registering a dialect") 42 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 43 py::arg("operation_name"), py::arg("operation_class"), 44 py::arg("raw_opview_class"), 45 "Testing hook for directly registering an operation"); 46 47 // Aside from making the globals accessible to python, having python manage 48 // it is necessary to make sure it is destroyed (and releases its python 49 // resources) properly. 50 m.attr("globals") = 51 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 52 53 // Registration decorators. 54 m.def( 55 "register_dialect", 56 [](py::object pyClass) { 57 std::string dialectNamespace = 58 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 59 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 60 return pyClass; 61 }, 62 py::arg("dialect_class"), 63 "Class decorator for registering a custom Dialect wrapper"); 64 m.def( 65 "register_operation", 66 [](py::object dialectClass) -> py::cpp_function { 67 return py::cpp_function( 68 [dialectClass](py::object opClass) -> py::object { 69 std::string operationName = 70 opClass.attr("OPERATION_NAME").cast<std::string>(); 71 auto rawSubclass = PyOpView::createRawSubclass(opClass); 72 PyGlobals::get().registerOperationImpl(operationName, opClass, 73 rawSubclass); 74 75 // Dict-stuff the new opClass by name onto the dialect class. 76 py::object opClassName = opClass.attr("__name__"); 77 dialectClass.attr(opClassName) = opClass; 78 79 // Now create a special "Raw" subclass that passes through 80 // construction to the OpView parent (bypasses the intermediate 81 // child's __init__). 82 opClass.attr("_Raw") = rawSubclass; 83 return opClass; 84 }); 85 }, 86 py::arg("dialect_class"), 87 "Produce a class decorator for registering an Operation class as part of " 88 "a dialect"); 89 90 // Define and populate IR submodule. 91 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 92 populateIRCore(irModule); 93 populateIRAffine(irModule); 94 populateIRAttributes(irModule); 95 populateIRInterfaces(irModule); 96 populateIRTypes(irModule); 97 98 // Define and populate PassManager submodule. 99 auto passModule = 100 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 101 populatePassManagerSubmodule(passModule); 102 } 103