1 //===- TransformInterpreter.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pybind classes for the transform dialect interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <pybind11/detail/common.h> 14 #include <pybind11/pybind11.h> 15 16 #include "mlir-c/Dialect/Transform/Interpreter.h" 17 #include "mlir-c/IR.h" 18 #include "mlir-c/Support.h" 19 #include "mlir/Bindings/Python/Diagnostics.h" 20 #include "mlir/Bindings/Python/PybindAdaptors.h" 21 22 namespace py = pybind11; 23 24 namespace { 25 struct PyMlirTransformOptions { 26 PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; 27 PyMlirTransformOptions(PyMlirTransformOptions &&other) { 28 options = other.options; 29 other.options.ptr = nullptr; 30 } 31 PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; 32 33 ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); } 34 35 MlirTransformOptions options; 36 }; 37 } // namespace 38 39 static void populateTransformInterpreterSubmodule(py::module &m) { 40 py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local()) 41 .def(py::init()) 42 .def_property( 43 "expensive_checks", 44 [](const PyMlirTransformOptions &self) { 45 return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); 46 }, 47 [](PyMlirTransformOptions &self, bool value) { 48 mlirTransformOptionsEnableExpensiveChecks(self.options, value); 49 }) 50 .def_property( 51 "enforce_single_top_level_transform_op", 52 [](const PyMlirTransformOptions &self) { 53 return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( 54 self.options); 55 }, 56 [](PyMlirTransformOptions &self, bool value) { 57 mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, 58 value); 59 }); 60 61 m.def( 62 "apply_named_sequence", 63 [](MlirOperation payloadRoot, MlirOperation transformRoot, 64 MlirOperation transformModule, const PyMlirTransformOptions &options) { 65 mlir::python::CollectDiagnosticsToStringScope scope( 66 mlirOperationGetContext(transformRoot)); 67 68 // Calling back into Python to invalidate everything under the payload 69 // root. This is awkward, but we don't have access to PyMlirContext 70 // object here otherwise. 71 py::object obj = py::cast(payloadRoot); 72 obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); 73 74 MlirLogicalResult result = mlirTransformApplyNamedSequence( 75 payloadRoot, transformRoot, transformModule, options.options); 76 if (mlirLogicalResultIsSuccess(result)) 77 return; 78 79 throw py::value_error( 80 "Failed to apply named transform sequence.\nDiagnostic message " + 81 scope.takeMessage()); 82 }, 83 py::arg("payload_root"), py::arg("transform_root"), 84 py::arg("transform_module"), 85 py::arg("transform_options") = PyMlirTransformOptions()); 86 87 m.def( 88 "copy_symbols_and_merge_into", 89 [](MlirOperation target, MlirOperation other) { 90 mlir::python::CollectDiagnosticsToStringScope scope( 91 mlirOperationGetContext(target)); 92 93 MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); 94 if (mlirLogicalResultIsFailure(result)) { 95 throw py::value_error( 96 "Failed to merge symbols.\nDiagnostic message " + 97 scope.takeMessage()); 98 } 99 }, 100 py::arg("target"), py::arg("other")); 101 } 102 103 PYBIND11_MODULE(_mlirTransformInterpreter, m) { 104 m.doc() = "MLIR Transform dialect interpreter functionality."; 105 populateTransformInterpreterSubmodule(m); 106 } 107