xref: /llvm-project/mlir/lib/Bindings/Python/TransformInterpreter.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
//===- TransformInterpreter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Nanobind classes for the transform dialect interpreter.
//
//===----------------------------------------------------------------------===//
1291f11611SOleksandr "Alex" Zinenko 
1391f11611SOleksandr "Alex" Zinenko #include "mlir-c/Dialect/Transform/Interpreter.h"
1491f11611SOleksandr "Alex" Zinenko #include "mlir-c/IR.h"
1591f11611SOleksandr "Alex" Zinenko #include "mlir-c/Support.h"
16392622d0SMaksim Levental #include "mlir/Bindings/Python/Diagnostics.h"
17*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h"
18*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h"
1991f11611SOleksandr "Alex" Zinenko 
20*5cd42747SPeter Hawkins namespace nb = nanobind;
2191f11611SOleksandr "Alex" Zinenko 
2291f11611SOleksandr "Alex" Zinenko namespace {
2391f11611SOleksandr "Alex" Zinenko struct PyMlirTransformOptions {
2491f11611SOleksandr "Alex" Zinenko   PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
250e5bde02SAdrian Kuegel   PyMlirTransformOptions(PyMlirTransformOptions &&other) {
2691f11611SOleksandr "Alex" Zinenko     options = other.options;
2791f11611SOleksandr "Alex" Zinenko     other.options.ptr = nullptr;
2891f11611SOleksandr "Alex" Zinenko   }
2991f11611SOleksandr "Alex" Zinenko   PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
3091f11611SOleksandr "Alex" Zinenko 
3191f11611SOleksandr "Alex" Zinenko   ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
3291f11611SOleksandr "Alex" Zinenko 
3391f11611SOleksandr "Alex" Zinenko   MlirTransformOptions options;
3491f11611SOleksandr "Alex" Zinenko };
3591f11611SOleksandr "Alex" Zinenko } // namespace
3691f11611SOleksandr "Alex" Zinenko 
37*5cd42747SPeter Hawkins static void populateTransformInterpreterSubmodule(nb::module_ &m) {
38*5cd42747SPeter Hawkins   nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
39*5cd42747SPeter Hawkins       .def(nb::init<>())
40*5cd42747SPeter Hawkins       .def_prop_rw(
4191f11611SOleksandr "Alex" Zinenko           "expensive_checks",
4291f11611SOleksandr "Alex" Zinenko           [](const PyMlirTransformOptions &self) {
4391f11611SOleksandr "Alex" Zinenko             return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
4491f11611SOleksandr "Alex" Zinenko           },
4591f11611SOleksandr "Alex" Zinenko           [](PyMlirTransformOptions &self, bool value) {
4691f11611SOleksandr "Alex" Zinenko             mlirTransformOptionsEnableExpensiveChecks(self.options, value);
4791f11611SOleksandr "Alex" Zinenko           })
48*5cd42747SPeter Hawkins       .def_prop_rw(
4991f11611SOleksandr "Alex" Zinenko           "enforce_single_top_level_transform_op",
5091f11611SOleksandr "Alex" Zinenko           [](const PyMlirTransformOptions &self) {
5191f11611SOleksandr "Alex" Zinenko             return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
5291f11611SOleksandr "Alex" Zinenko                 self.options);
5391f11611SOleksandr "Alex" Zinenko           },
5491f11611SOleksandr "Alex" Zinenko           [](PyMlirTransformOptions &self, bool value) {
5591f11611SOleksandr "Alex" Zinenko             mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
5691f11611SOleksandr "Alex" Zinenko                                                                  value);
5791f11611SOleksandr "Alex" Zinenko           });
5891f11611SOleksandr "Alex" Zinenko 
5991f11611SOleksandr "Alex" Zinenko   m.def(
6091f11611SOleksandr "Alex" Zinenko       "apply_named_sequence",
6191f11611SOleksandr "Alex" Zinenko       [](MlirOperation payloadRoot, MlirOperation transformRoot,
6291f11611SOleksandr "Alex" Zinenko          MlirOperation transformModule, const PyMlirTransformOptions &options) {
6391f11611SOleksandr "Alex" Zinenko         mlir::python::CollectDiagnosticsToStringScope scope(
6491f11611SOleksandr "Alex" Zinenko             mlirOperationGetContext(transformRoot));
6591f11611SOleksandr "Alex" Zinenko 
6691f11611SOleksandr "Alex" Zinenko         // Calling back into Python to invalidate everything under the payload
6791f11611SOleksandr "Alex" Zinenko         // root. This is awkward, but we don't have access to PyMlirContext
6891f11611SOleksandr "Alex" Zinenko         // object here otherwise.
69*5cd42747SPeter Hawkins         nb::object obj = nb::cast(payloadRoot);
7091f11611SOleksandr "Alex" Zinenko         obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
7191f11611SOleksandr "Alex" Zinenko 
7291f11611SOleksandr "Alex" Zinenko         MlirLogicalResult result = mlirTransformApplyNamedSequence(
7391f11611SOleksandr "Alex" Zinenko             payloadRoot, transformRoot, transformModule, options.options);
7491f11611SOleksandr "Alex" Zinenko         if (mlirLogicalResultIsSuccess(result))
7591f11611SOleksandr "Alex" Zinenko           return;
7691f11611SOleksandr "Alex" Zinenko 
77*5cd42747SPeter Hawkins         throw nb::value_error(
78*5cd42747SPeter Hawkins             ("Failed to apply named transform sequence.\nDiagnostic message " +
79*5cd42747SPeter Hawkins              scope.takeMessage())
80*5cd42747SPeter Hawkins                 .c_str());
8191f11611SOleksandr "Alex" Zinenko       },
82*5cd42747SPeter Hawkins       nb::arg("payload_root"), nb::arg("transform_root"),
83*5cd42747SPeter Hawkins       nb::arg("transform_module"),
84*5cd42747SPeter Hawkins       nb::arg("transform_options") = PyMlirTransformOptions());
8573140daeSOleksandr "Alex" Zinenko 
8673140daeSOleksandr "Alex" Zinenko   m.def(
8773140daeSOleksandr "Alex" Zinenko       "copy_symbols_and_merge_into",
8873140daeSOleksandr "Alex" Zinenko       [](MlirOperation target, MlirOperation other) {
8973140daeSOleksandr "Alex" Zinenko         mlir::python::CollectDiagnosticsToStringScope scope(
9073140daeSOleksandr "Alex" Zinenko             mlirOperationGetContext(target));
9173140daeSOleksandr "Alex" Zinenko 
9273140daeSOleksandr "Alex" Zinenko         MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
9373140daeSOleksandr "Alex" Zinenko         if (mlirLogicalResultIsFailure(result)) {
94*5cd42747SPeter Hawkins           throw nb::value_error(
95*5cd42747SPeter Hawkins               ("Failed to merge symbols.\nDiagnostic message " +
96*5cd42747SPeter Hawkins                scope.takeMessage())
97*5cd42747SPeter Hawkins                   .c_str());
9873140daeSOleksandr "Alex" Zinenko         }
9973140daeSOleksandr "Alex" Zinenko       },
100*5cd42747SPeter Hawkins       nb::arg("target"), nb::arg("other"));
10191f11611SOleksandr "Alex" Zinenko }
10291f11611SOleksandr "Alex" Zinenko 
103*5cd42747SPeter Hawkins NB_MODULE(_mlirTransformInterpreter, m) {
10491f11611SOleksandr "Alex" Zinenko   m.doc() = "MLIR Transform dialect interpreter functionality.";
10591f11611SOleksandr "Alex" Zinenko   populateTransformInterpreterSubmodule(m);
10691f11611SOleksandr "Alex" Zinenko }
107