xref: /llvm-project/mlir/lib/Bindings/Python/TransformInterpreter.cpp (revision 73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd)
1 //===- TransformInterpreter.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pybind classes for the transform dialect interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir-c/Dialect/Transform/Interpreter.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Support.h"
16 #include "mlir/Bindings/Python/PybindAdaptors.h"
17 
18 #include <pybind11/detail/common.h>
19 #include <pybind11/pybind11.h>
20 
21 namespace py = pybind11;
22 
23 namespace {
24 struct PyMlirTransformOptions {
25   PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
26   PyMlirTransformOptions(PyMlirTransformOptions &&other) {
27     options = other.options;
28     other.options.ptr = nullptr;
29   }
30   PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
31 
32   ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
33 
34   MlirTransformOptions options;
35 };
36 } // namespace
37 
38 static void populateTransformInterpreterSubmodule(py::module &m) {
39   py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
40       .def(py::init())
41       .def_property(
42           "expensive_checks",
43           [](const PyMlirTransformOptions &self) {
44             return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
45           },
46           [](PyMlirTransformOptions &self, bool value) {
47             mlirTransformOptionsEnableExpensiveChecks(self.options, value);
48           })
49       .def_property(
50           "enforce_single_top_level_transform_op",
51           [](const PyMlirTransformOptions &self) {
52             return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
53                 self.options);
54           },
55           [](PyMlirTransformOptions &self, bool value) {
56             mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
57                                                                  value);
58           });
59 
60   m.def(
61       "apply_named_sequence",
62       [](MlirOperation payloadRoot, MlirOperation transformRoot,
63          MlirOperation transformModule, const PyMlirTransformOptions &options) {
64         mlir::python::CollectDiagnosticsToStringScope scope(
65             mlirOperationGetContext(transformRoot));
66 
67         // Calling back into Python to invalidate everything under the payload
68         // root. This is awkward, but we don't have access to PyMlirContext
69         // object here otherwise.
70         py::object obj = py::cast(payloadRoot);
71         obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
72 
73         MlirLogicalResult result = mlirTransformApplyNamedSequence(
74             payloadRoot, transformRoot, transformModule, options.options);
75         if (mlirLogicalResultIsSuccess(result))
76           return;
77 
78         throw py::value_error(
79             "Failed to apply named transform sequence.\nDiagnostic message " +
80             scope.takeMessage());
81       },
82       py::arg("payload_root"), py::arg("transform_root"),
83       py::arg("transform_module"),
84       py::arg("transform_options") = PyMlirTransformOptions());
85 
86   m.def(
87       "copy_symbols_and_merge_into",
88       [](MlirOperation target, MlirOperation other) {
89         mlir::python::CollectDiagnosticsToStringScope scope(
90             mlirOperationGetContext(target));
91 
92         MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
93         if (mlirLogicalResultIsFailure(result)) {
94           throw py::value_error(
95               "Failed to merge symbols.\nDiagnostic message " +
96               scope.takeMessage());
97         }
98       },
99       py::arg("target"), py::arg("other"));
100 }
101 
102 PYBIND11_MODULE(_mlirTransformInterpreter, m) {
103   m.doc() = "MLIR Transform dialect interpreter functionality.";
104   populateTransformInterpreterSubmodule(m);
105 }
106