1 //===- MainModule.cpp - Main pybind module --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <tuple> 10 11 #include "PybindUtils.h" 12 13 #include "Dialects.h" 14 #include "Globals.h" 15 #include "IRModule.h" 16 #include "Pass.h" 17 18 namespace py = pybind11; 19 using namespace mlir; 20 using namespace mlir::python; 21 22 // ----------------------------------------------------------------------------- 23 // Module initialization. 24 // ----------------------------------------------------------------------------- 25 26 PYBIND11_MODULE(_mlir, m) { 27 m.doc() = "MLIR Python Native Extension"; 28 29 py::class_<PyGlobals>(m, "_Globals") 30 .def_property("dialect_search_modules", 31 &PyGlobals::getDialectSearchPrefixes, 32 &PyGlobals::setDialectSearchPrefixes) 33 .def("append_dialect_search_prefix", 34 [](PyGlobals &self, std::string moduleName) { 35 self.getDialectSearchPrefixes().push_back(std::move(moduleName)); 36 self.clearImportCache(); 37 }) 38 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, 39 "Testing hook for directly registering a dialect") 40 .def("_register_operation_impl", &PyGlobals::registerOperationImpl, 41 "Testing hook for directly registering an operation"); 42 43 // Aside from making the globals accessible to python, having python manage 44 // it is necessary to make sure it is destroyed (and releases its python 45 // resources) properly. 46 m.attr("globals") = 47 py::cast(new PyGlobals, py::return_value_policy::take_ownership); 48 49 // Registration decorators. 50 m.def( 51 "register_dialect", 52 [](py::object pyClass) { 53 std::string dialectNamespace = 54 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>(); 55 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); 56 return pyClass; 57 }, 58 "Class decorator for registering a custom Dialect wrapper"); 59 m.def( 60 "register_operation", 61 [](py::object dialectClass) -> py::cpp_function { 62 return py::cpp_function( 63 [dialectClass](py::object opClass) -> py::object { 64 std::string operationName = 65 opClass.attr("OPERATION_NAME").cast<std::string>(); 66 auto rawSubclass = PyOpView::createRawSubclass(opClass); 67 PyGlobals::get().registerOperationImpl(operationName, opClass, 68 rawSubclass); 69 70 // Dict-stuff the new opClass by name onto the dialect class. 71 py::object opClassName = opClass.attr("__name__"); 72 dialectClass.attr(opClassName) = opClass; 73 74 // Now create a special "Raw" subclass that passes through 75 // construction to the OpView parent (bypasses the intermediate 76 // child's __init__). 77 opClass.attr("_Raw") = rawSubclass; 78 return opClass; 79 }); 80 }, 81 "Class decorator for registering a custom Operation wrapper"); 82 83 // Define and populate IR submodule. 84 auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); 85 populateIRCore(irModule); 86 populateIRAffine(irModule); 87 populateIRAttributes(irModule); 88 populateIRTypes(irModule); 89 90 // Define and populate PassManager submodule. 91 auto passModule = 92 m.def_submodule("passmanager", "MLIR Pass Management Bindings"); 93 populatePassManagerSubmodule(passModule); 94 95 // Define and populate dialect submodules. 96 auto dialectsModule = m.def_submodule("dialects"); 97 auto linalgModule = dialectsModule.def_submodule("linalg"); 98 populateDialectLinalgSubmodule(linalgModule); 99 populateDialectSparseTensorSubmodule( 100 dialectsModule.def_submodule("sparse_tensor"), irModule); 101 } 102