xref: /llvm-project/mlir/lib/Bindings/Python/TransformInterpreter.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- TransformInterpreter.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Nanobind classes for the transform dialect interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir-c/Dialect/Transform/Interpreter.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Support.h"
16 #include "mlir/Bindings/Python/Diagnostics.h"
17 #include "mlir/Bindings/Python/NanobindAdaptors.h"
18 #include "mlir/Bindings/Python/Nanobind.h"
19 
20 namespace nb = nanobind;
21 
22 namespace {
23 struct PyMlirTransformOptions {
24   PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
25   PyMlirTransformOptions(PyMlirTransformOptions &&other) {
26     options = other.options;
27     other.options.ptr = nullptr;
28   }
29   PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
30 
31   ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
32 
33   MlirTransformOptions options;
34 };
35 } // namespace
36 
37 static void populateTransformInterpreterSubmodule(nb::module_ &m) {
38   nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
39       .def(nb::init<>())
40       .def_prop_rw(
41           "expensive_checks",
42           [](const PyMlirTransformOptions &self) {
43             return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
44           },
45           [](PyMlirTransformOptions &self, bool value) {
46             mlirTransformOptionsEnableExpensiveChecks(self.options, value);
47           })
48       .def_prop_rw(
49           "enforce_single_top_level_transform_op",
50           [](const PyMlirTransformOptions &self) {
51             return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
52                 self.options);
53           },
54           [](PyMlirTransformOptions &self, bool value) {
55             mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
56                                                                  value);
57           });
58 
59   m.def(
60       "apply_named_sequence",
61       [](MlirOperation payloadRoot, MlirOperation transformRoot,
62          MlirOperation transformModule, const PyMlirTransformOptions &options) {
63         mlir::python::CollectDiagnosticsToStringScope scope(
64             mlirOperationGetContext(transformRoot));
65 
66         // Calling back into Python to invalidate everything under the payload
67         // root. This is awkward, but we don't have access to PyMlirContext
68         // object here otherwise.
69         nb::object obj = nb::cast(payloadRoot);
70         obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
71 
72         MlirLogicalResult result = mlirTransformApplyNamedSequence(
73             payloadRoot, transformRoot, transformModule, options.options);
74         if (mlirLogicalResultIsSuccess(result))
75           return;
76 
77         throw nb::value_error(
78             ("Failed to apply named transform sequence.\nDiagnostic message " +
79              scope.takeMessage())
80                 .c_str());
81       },
82       nb::arg("payload_root"), nb::arg("transform_root"),
83       nb::arg("transform_module"),
84       nb::arg("transform_options") = PyMlirTransformOptions());
85 
86   m.def(
87       "copy_symbols_and_merge_into",
88       [](MlirOperation target, MlirOperation other) {
89         mlir::python::CollectDiagnosticsToStringScope scope(
90             mlirOperationGetContext(target));
91 
92         MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
93         if (mlirLogicalResultIsFailure(result)) {
94           throw nb::value_error(
95               ("Failed to merge symbols.\nDiagnostic message " +
96                scope.takeMessage())
97                   .c_str());
98         }
99       },
100       nb::arg("target"), nb::arg("other"));
101 }
102 
103 NB_MODULE(_mlirTransformInterpreter, m) {
104   m.doc() = "MLIR Transform dialect interpreter functionality.";
105   populateTransformInterpreterSubmodule(m);
106 }
107