1 //===- TransformInterpreter.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pybind classes for the transform dialect interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir-c/Dialect/Transform/Interpreter.h" 14 #include "mlir-c/IR.h" 15 #include "mlir-c/Support.h" 16 #include "mlir/Bindings/Python/Diagnostics.h" 17 #include "mlir/Bindings/Python/NanobindAdaptors.h" 18 #include "mlir/Bindings/Python/Nanobind.h" 19 20 namespace nb = nanobind; 21 22 namespace { 23 struct PyMlirTransformOptions { 24 PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; 25 PyMlirTransformOptions(PyMlirTransformOptions &&other) { 26 options = other.options; 27 other.options.ptr = nullptr; 28 } 29 PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; 30 31 ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); } 32 33 MlirTransformOptions options; 34 }; 35 } // namespace 36 37 static void populateTransformInterpreterSubmodule(nb::module_ &m) { 38 nb::class_<PyMlirTransformOptions>(m, "TransformOptions") 39 .def(nb::init<>()) 40 .def_prop_rw( 41 "expensive_checks", 42 [](const PyMlirTransformOptions &self) { 43 return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); 44 }, 45 [](PyMlirTransformOptions &self, bool value) { 46 mlirTransformOptionsEnableExpensiveChecks(self.options, value); 47 }) 48 .def_prop_rw( 49 "enforce_single_top_level_transform_op", 50 [](const PyMlirTransformOptions &self) { 51 return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( 52 self.options); 53 }, 54 [](PyMlirTransformOptions &self, bool value) { 55 mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, 56 value); 57 }); 58 59 m.def( 60 "apply_named_sequence", 61 [](MlirOperation payloadRoot, MlirOperation transformRoot, 62 MlirOperation transformModule, const PyMlirTransformOptions &options) { 63 mlir::python::CollectDiagnosticsToStringScope scope( 64 mlirOperationGetContext(transformRoot)); 65 66 // Calling back into Python to invalidate everything under the payload 67 // root. This is awkward, but we don't have access to PyMlirContext 68 // object here otherwise. 69 nb::object obj = nb::cast(payloadRoot); 70 obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); 71 72 MlirLogicalResult result = mlirTransformApplyNamedSequence( 73 payloadRoot, transformRoot, transformModule, options.options); 74 if (mlirLogicalResultIsSuccess(result)) 75 return; 76 77 throw nb::value_error( 78 ("Failed to apply named transform sequence.\nDiagnostic message " + 79 scope.takeMessage()) 80 .c_str()); 81 }, 82 nb::arg("payload_root"), nb::arg("transform_root"), 83 nb::arg("transform_module"), 84 nb::arg("transform_options") = PyMlirTransformOptions()); 85 86 m.def( 87 "copy_symbols_and_merge_into", 88 [](MlirOperation target, MlirOperation other) { 89 mlir::python::CollectDiagnosticsToStringScope scope( 90 mlirOperationGetContext(target)); 91 92 MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); 93 if (mlirLogicalResultIsFailure(result)) { 94 throw nb::value_error( 95 ("Failed to merge symbols.\nDiagnostic message " + 96 scope.takeMessage()) 97 .c_str()); 98 } 99 }, 100 nb::arg("target"), nb::arg("other")); 101 } 102 103 NB_MODULE(_mlirTransformInterpreter, m) { 104 m.doc() = "MLIR Transform dialect interpreter functionality."; 105 populateTransformInterpreterSubmodule(m); 106 } 107