1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Dialects.h" 14 #include "ExecutionEngine.h" 15 #include "Globals.h" 16 #include "IRModule.h" 17 #include "Pass.h" 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 23 // ----------------------------------------------------------------------------- 24 // Module initialization. 25 // ----------------------------------------------------------------------------- 26 27 PYBIND11_MODULE(_mlir, m) { 28 m.doc() = "MLIR Python Native Extension"; 29 30 py::class_<PyGlobals>(m, "_Globals") 31 .def_property("dialect_search_modules", 32 &PyGlobals::getDialectSearchPrefixes, 33 &PyGlobals::setDialectSearchPrefixes) 34 .def("append_dialect_search_prefix", 35 [](PyGlobals &self, std::string moduleName) { 36 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 37 self.clearImportCache(); 38 }) 39 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 40 "Testing hook for directly registering a dialect") 41 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 42 "Testing hook for directly registering an operation"); 43 44 // Aside from making the globals accessible to python, having python manage 45 // it is necessary to make sure it is destroyed (and releases its python 46 // resources) properly. 47 m.attr("globals") = 48 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 49 50 // Registration decorators. 51 m.def( 52 "register_dialect", 53 [](py::object pyClass) { 54 std::string dialectNamespace = 55 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 56 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 57 return pyClass; 58 }, 59 "Class decorator for registering a custom Dialect wrapper"); 60 m.def( 61 "register_operation", 62 [](py::object dialectClass) -> py::cpp_function { 63 return py::cpp_function( 64 [dialectClass](py::object opClass) -> py::object { 65 std::string operationName = 66 opClass.attr("OPERATION_NAME").cast<std::string>(); 67 auto rawSubclass = PyOpView::createRawSubclass(opClass); 68 PyGlobals::get().registerOperationImpl(operationName, opClass, 69 rawSubclass); 70 71 // Dict-stuff the new opClass by name onto the dialect class. 72 py::object opClassName = opClass.attr("__name__"); 73 dialectClass.attr(opClassName) = opClass; 74 75 // Now create a special "Raw" subclass that passes through 76 // construction to the OpView parent (bypasses the intermediate 77 // child's __init__). 78 opClass.attr("_Raw") = rawSubclass; 79 return opClass; 80 }); 81 }, 82 "Class decorator for registering a custom Operation wrapper"); 83 84 // Define and populate IR submodule. 85 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 86 populateIRCore(irModule); 87 populateIRAffine(irModule); 88 populateIRAttributes(irModule); 89 populateIRTypes(irModule); 90 91 // Define and populate PassManager submodule. 92 auto passModule = 93 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 94 populatePassManagerSubmodule(passModule); 95 96 // Define and populate ExecutionEngine submodule. 97 auto executionEngineModule = 98 m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); 99 populateExecutionEngineSubmodule(executionEngineModule); 100 101 // Define and populate dialect submodules. 102 auto dialectsModule = m.def_submodule("dialects"); 103 auto linalgModule = dialectsModule.def_submodule("linalg"); 104 populateDialectLinalgSubmodule(linalgModule); 105 populateDialectSparseTensorSubmodule( 106 dialectsModule.def_submodule("sparse_tensor"), irModule); 107 } 108