xref: /llvm-project/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- ExecutionEngineModule.cpp - Python module for execution engine -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/ExecutionEngine.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
12 
13 namespace nb = nanobind;
14 using namespace mlir;
15 using namespace mlir::python;
16 
17 namespace {
18 
19 /// Owning Wrapper around an ExecutionEngine.
20 class PyExecutionEngine {
21 public:
22   PyExecutionEngine(MlirExecutionEngine executionEngine)
23       : executionEngine(executionEngine) {}
24   PyExecutionEngine(PyExecutionEngine &&other) noexcept
25       : executionEngine(other.executionEngine) {
26     other.executionEngine.ptr = nullptr;
27   }
28   ~PyExecutionEngine() {
29     if (!mlirExecutionEngineIsNull(executionEngine))
30       mlirExecutionEngineDestroy(executionEngine);
31   }
32   MlirExecutionEngine get() { return executionEngine; }
33 
34   void release() {
35     executionEngine.ptr = nullptr;
36     referencedObjects.clear();
37   }
38   nb::object getCapsule() {
39     return nb::steal<nb::object>(mlirPythonExecutionEngineToCapsule(get()));
40   }
41 
42   // Add an object to the list of referenced objects whose lifetime must exceed
43   // those of the ExecutionEngine.
44   void addReferencedObject(const nb::object &obj) {
45     referencedObjects.push_back(obj);
46   }
47 
48   static nb::object createFromCapsule(nb::object capsule) {
49     MlirExecutionEngine rawPm =
50         mlirPythonCapsuleToExecutionEngine(capsule.ptr());
51     if (mlirExecutionEngineIsNull(rawPm))
52       throw nb::python_error();
53     return nb::cast(PyExecutionEngine(rawPm), nb::rv_policy::move);
54   }
55 
56 private:
57   MlirExecutionEngine executionEngine;
58   // We support Python ctypes closures as callbacks. Keep a list of the objects
59   // so that they don't get garbage collected. (The ExecutionEngine itself
60   // just holds raw pointers with no lifetime semantics).
61   std::vector<nb::object> referencedObjects;
62 };
63 
64 } // namespace
65 
66 /// Create the `mlir.execution_engine` module here.
67 NB_MODULE(_mlirExecutionEngine, m) {
68   m.doc() = "MLIR Execution Engine";
69 
70   //----------------------------------------------------------------------------
71   // Mapping of the top-level PassManager
72   //----------------------------------------------------------------------------
73   nb::class_<PyExecutionEngine>(m, "ExecutionEngine")
74       .def(
75           "__init__",
76           [](PyExecutionEngine &self, MlirModule module, int optLevel,
77              const std::vector<std::string> &sharedLibPaths,
78              bool enableObjectDump) {
79             llvm::SmallVector<MlirStringRef, 4> libPaths;
80             for (const std::string &path : sharedLibPaths)
81               libPaths.push_back({path.c_str(), path.length()});
82             MlirExecutionEngine executionEngine =
83                 mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
84                                           libPaths.data(), enableObjectDump);
85             if (mlirExecutionEngineIsNull(executionEngine))
86               throw std::runtime_error(
87                   "Failure while creating the ExecutionEngine.");
88             new (&self) PyExecutionEngine(executionEngine);
89           },
90           nb::arg("module"), nb::arg("opt_level") = 2,
91           nb::arg("shared_libs") = nb::list(),
92           nb::arg("enable_object_dump") = true,
93           "Create a new ExecutionEngine instance for the given Module. The "
94           "module must contain only dialects that can be translated to LLVM. "
95           "Perform transformations and code generation at the optimization "
96           "level `opt_level` if specified, or otherwise at the default "
97           "level of two (-O2). Load a list of libraries specified in "
98           "`shared_libs`.")
99       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule)
100       .def("_testing_release", &PyExecutionEngine::release,
101            "Releases (leaks) the backing ExecutionEngine (for testing purpose)")
102       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule)
103       .def(
104           "raw_lookup",
105           [](PyExecutionEngine &executionEngine, const std::string &func) {
106             auto *res = mlirExecutionEngineLookupPacked(
107                 executionEngine.get(),
108                 mlirStringRefCreate(func.c_str(), func.size()));
109             return reinterpret_cast<uintptr_t>(res);
110           },
111           nb::arg("func_name"),
112           "Lookup function `func` in the ExecutionEngine.")
113       .def(
114           "raw_register_runtime",
115           [](PyExecutionEngine &executionEngine, const std::string &name,
116              nb::object callbackObj) {
117             executionEngine.addReferencedObject(callbackObj);
118             uintptr_t rawSym =
119                 nb::cast<uintptr_t>(nb::getattr(callbackObj, "value"));
120             mlirExecutionEngineRegisterSymbol(
121                 executionEngine.get(),
122                 mlirStringRefCreate(name.c_str(), name.size()),
123                 reinterpret_cast<void *>(rawSym));
124           },
125           nb::arg("name"), nb::arg("callback"),
126           "Register `callback` as the runtime symbol `name`.")
127       .def(
128           "dump_to_object_file",
129           [](PyExecutionEngine &executionEngine, const std::string &fileName) {
130             mlirExecutionEngineDumpToObjectFile(
131                 executionEngine.get(),
132                 mlirStringRefCreate(fileName.c_str(), fileName.size()));
133           },
134           nb::arg("file_name"), "Dump ExecutionEngine to an object file.");
135 }
136