xref: /llvm-project/mlir/lib/Bindings/Python/TransformInterpreter.cpp (revision afe75b4d5fcebd6fdd292ca1797de1b35cb687b0)
1 //===- TransformInterpreter.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pybind classes for the transform dialect interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <pybind11/detail/common.h>
14 #include <pybind11/pybind11.h>
15 
16 #include "mlir-c/Dialect/Transform/Interpreter.h"
17 #include "mlir-c/IR.h"
18 #include "mlir-c/Support.h"
19 #include "mlir/Bindings/Python/Diagnostics.h"
20 #include "mlir/Bindings/Python/PybindAdaptors.h"
21 
22 namespace py = pybind11;
23 
24 namespace {
25 struct PyMlirTransformOptions {
26   PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
27   PyMlirTransformOptions(PyMlirTransformOptions &&other) {
28     options = other.options;
29     other.options.ptr = nullptr;
30   }
31   PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
32 
33   ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
34 
35   MlirTransformOptions options;
36 };
37 } // namespace
38 
39 static void populateTransformInterpreterSubmodule(py::module &m) {
40   py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
41       .def(py::init())
42       .def_property(
43           "expensive_checks",
44           [](const PyMlirTransformOptions &self) {
45             return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
46           },
47           [](PyMlirTransformOptions &self, bool value) {
48             mlirTransformOptionsEnableExpensiveChecks(self.options, value);
49           })
50       .def_property(
51           "enforce_single_top_level_transform_op",
52           [](const PyMlirTransformOptions &self) {
53             return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
54                 self.options);
55           },
56           [](PyMlirTransformOptions &self, bool value) {
57             mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
58                                                                  value);
59           });
60 
61   m.def(
62       "apply_named_sequence",
63       [](MlirOperation payloadRoot, MlirOperation transformRoot,
64          MlirOperation transformModule, const PyMlirTransformOptions &options) {
65         mlir::python::CollectDiagnosticsToStringScope scope(
66             mlirOperationGetContext(transformRoot));
67 
68         // Calling back into Python to invalidate everything under the payload
69         // root. This is awkward, but we don't have access to PyMlirContext
70         // object here otherwise.
71         py::object obj = py::cast(payloadRoot);
72         obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
73 
74         MlirLogicalResult result = mlirTransformApplyNamedSequence(
75             payloadRoot, transformRoot, transformModule, options.options);
76         if (mlirLogicalResultIsSuccess(result))
77           return;
78 
79         throw py::value_error(
80             "Failed to apply named transform sequence.\nDiagnostic message " +
81             scope.takeMessage());
82       },
83       py::arg("payload_root"), py::arg("transform_root"),
84       py::arg("transform_module"),
85       py::arg("transform_options") = PyMlirTransformOptions());
86 
87   m.def(
88       "copy_symbols_and_merge_into",
89       [](MlirOperation target, MlirOperation other) {
90         mlir::python::CollectDiagnosticsToStringScope scope(
91             mlirOperationGetContext(target));
92 
93         MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
94         if (mlirLogicalResultIsFailure(result)) {
95           throw py::value_error(
96               "Failed to merge symbols.\nDiagnostic message " +
97               scope.takeMessage());
98         }
99       },
100       py::arg("target"), py::arg("other"));
101 }
102 
103 PYBIND11_MODULE(_mlirTransformInterpreter, m) {
104   m.doc() = "MLIR Transform dialect interpreter functionality.";
105   populateTransformInterpreterSubmodule(m);
106 }
107