1 //===- TransformInterpreter.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pybind classes for the transform dialect interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir-c/Dialect/Transform/Interpreter.h" 14 #include "mlir-c/IR.h" 15 #include "mlir-c/Support.h" 16 #include "mlir/Bindings/Python/PybindAdaptors.h" 17 18 #include <pybind11/detail/common.h> 19 #include <pybind11/pybind11.h> 20 21 namespace py = pybind11; 22 23 namespace { 24 struct PyMlirTransformOptions { 25 PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; 26 PyMlirTransformOptions(PyMlirTransformOptions &&other) { 27 options = other.options; 28 other.options.ptr = nullptr; 29 } 30 PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; 31 32 ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); } 33 34 MlirTransformOptions options; 35 }; 36 } // namespace 37 38 static void populateTransformInterpreterSubmodule(py::module &m) { 39 py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local()) 40 .def(py::init()) 41 .def_property( 42 "expensive_checks", 43 [](const PyMlirTransformOptions &self) { 44 return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); 45 }, 46 [](PyMlirTransformOptions &self, bool value) { 47 mlirTransformOptionsEnableExpensiveChecks(self.options, value); 48 }) 49 .def_property( 50 "enforce_single_top_level_transform_op", 51 [](const PyMlirTransformOptions &self) { 52 return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( 53 self.options); 54 }, 55 [](PyMlirTransformOptions &self, bool value) { 56 mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, 57 value); 58 }); 59 60 m.def( 61 "apply_named_sequence", 62 [](MlirOperation payloadRoot, MlirOperation transformRoot, 63 MlirOperation transformModule, const PyMlirTransformOptions &options) { 64 mlir::python::CollectDiagnosticsToStringScope scope( 65 mlirOperationGetContext(transformRoot)); 66 67 // Calling back into Python to invalidate everything under the payload 68 // root. This is awkward, but we don't have access to PyMlirContext 69 // object here otherwise. 70 py::object obj = py::cast(payloadRoot); 71 obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); 72 73 MlirLogicalResult result = mlirTransformApplyNamedSequence( 74 payloadRoot, transformRoot, transformModule, options.options); 75 if (mlirLogicalResultIsSuccess(result)) 76 return; 77 78 throw py::value_error( 79 "Failed to apply named transform sequence.\nDiagnostic message " + 80 scope.takeMessage()); 81 }, 82 py::arg("payload_root"), py::arg("transform_root"), 83 py::arg("transform_module"), 84 py::arg("transform_options") = PyMlirTransformOptions()); 85 86 m.def( 87 "copy_symbols_and_merge_into", 88 [](MlirOperation target, MlirOperation other) { 89 mlir::python::CollectDiagnosticsToStringScope scope( 90 mlirOperationGetContext(target)); 91 92 MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); 93 if (mlirLogicalResultIsFailure(result)) { 94 throw py::value_error( 95 "Failed to merge symbols.\nDiagnostic message " + 96 scope.takeMessage()); 97 } 98 }, 99 py::arg("target"), py::arg("other")); 100 } 101 102 PYBIND11_MODULE(_mlirTransformInterpreter, m) { 103 m.doc() = "MLIR Transform dialect interpreter functionality."; 104 populateTransformInterpreterSubmodule(m); 105 } 106