191f11611SOleksandr "Alex" Zinenko //===- TransformInterpreter.cpp -------------------------------------------===// 291f11611SOleksandr "Alex" Zinenko // 391f11611SOleksandr "Alex" Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 491f11611SOleksandr "Alex" Zinenko // See https://llvm.org/LICENSE.txt for license information. 591f11611SOleksandr "Alex" Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 691f11611SOleksandr "Alex" Zinenko // 791f11611SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 891f11611SOleksandr "Alex" Zinenko // 991f11611SOleksandr "Alex" Zinenko // Pybind classes for the transform dialect interpreter. 1091f11611SOleksandr "Alex" Zinenko // 1191f11611SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 1291f11611SOleksandr "Alex" Zinenko 1391f11611SOleksandr "Alex" Zinenko #include "mlir-c/Dialect/Transform/Interpreter.h" 1491f11611SOleksandr "Alex" Zinenko #include "mlir-c/IR.h" 1591f11611SOleksandr "Alex" Zinenko #include "mlir-c/Support.h" 16392622d0SMaksim Levental #include "mlir/Bindings/Python/Diagnostics.h" 17*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h" 18*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 1991f11611SOleksandr "Alex" Zinenko 20*5cd42747SPeter Hawkins namespace nb = nanobind; 2191f11611SOleksandr "Alex" Zinenko 2291f11611SOleksandr "Alex" Zinenko namespace { 2391f11611SOleksandr "Alex" Zinenko struct PyMlirTransformOptions { 2491f11611SOleksandr "Alex" Zinenko PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; 250e5bde02SAdrian Kuegel PyMlirTransformOptions(PyMlirTransformOptions &&other) { 2691f11611SOleksandr "Alex" Zinenko options = other.options; 2791f11611SOleksandr "Alex" Zinenko other.options.ptr = nullptr; 2891f11611SOleksandr "Alex" Zinenko } 2991f11611SOleksandr "Alex" Zinenko PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; 3091f11611SOleksandr "Alex" Zinenko 3191f11611SOleksandr "Alex" Zinenko ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); } 3291f11611SOleksandr "Alex" Zinenko 3391f11611SOleksandr "Alex" Zinenko MlirTransformOptions options; 3491f11611SOleksandr "Alex" Zinenko }; 3591f11611SOleksandr "Alex" Zinenko } // namespace 3691f11611SOleksandr "Alex" Zinenko 37*5cd42747SPeter Hawkins static void populateTransformInterpreterSubmodule(nb::module_ &m) { 38*5cd42747SPeter Hawkins nb::class_<PyMlirTransformOptions>(m, "TransformOptions") 39*5cd42747SPeter Hawkins .def(nb::init<>()) 40*5cd42747SPeter Hawkins .def_prop_rw( 4191f11611SOleksandr "Alex" Zinenko "expensive_checks", 4291f11611SOleksandr "Alex" Zinenko [](const PyMlirTransformOptions &self) { 4391f11611SOleksandr "Alex" Zinenko return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); 4491f11611SOleksandr "Alex" Zinenko }, 4591f11611SOleksandr "Alex" Zinenko [](PyMlirTransformOptions &self, bool value) { 4691f11611SOleksandr "Alex" Zinenko mlirTransformOptionsEnableExpensiveChecks(self.options, value); 4791f11611SOleksandr "Alex" Zinenko }) 48*5cd42747SPeter Hawkins .def_prop_rw( 4991f11611SOleksandr "Alex" Zinenko "enforce_single_top_level_transform_op", 5091f11611SOleksandr "Alex" Zinenko [](const PyMlirTransformOptions &self) { 5191f11611SOleksandr "Alex" Zinenko return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( 5291f11611SOleksandr "Alex" Zinenko self.options); 5391f11611SOleksandr "Alex" Zinenko }, 5491f11611SOleksandr "Alex" Zinenko [](PyMlirTransformOptions &self, bool value) { 5591f11611SOleksandr "Alex" Zinenko mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, 5691f11611SOleksandr "Alex" Zinenko value); 5791f11611SOleksandr "Alex" Zinenko }); 5891f11611SOleksandr "Alex" Zinenko 5991f11611SOleksandr "Alex" Zinenko m.def( 6091f11611SOleksandr "Alex" Zinenko "apply_named_sequence", 6191f11611SOleksandr "Alex" Zinenko [](MlirOperation payloadRoot, MlirOperation transformRoot, 6291f11611SOleksandr "Alex" Zinenko MlirOperation transformModule, const PyMlirTransformOptions &options) { 6391f11611SOleksandr "Alex" Zinenko mlir::python::CollectDiagnosticsToStringScope scope( 6491f11611SOleksandr "Alex" Zinenko mlirOperationGetContext(transformRoot)); 6591f11611SOleksandr "Alex" Zinenko 6691f11611SOleksandr "Alex" Zinenko // Calling back into Python to invalidate everything under the payload 6791f11611SOleksandr "Alex" Zinenko // root. This is awkward, but we don't have access to PyMlirContext 6891f11611SOleksandr "Alex" Zinenko // object here otherwise. 69*5cd42747SPeter Hawkins nb::object obj = nb::cast(payloadRoot); 7091f11611SOleksandr "Alex" Zinenko obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); 7191f11611SOleksandr "Alex" Zinenko 7291f11611SOleksandr "Alex" Zinenko MlirLogicalResult result = mlirTransformApplyNamedSequence( 7391f11611SOleksandr "Alex" Zinenko payloadRoot, transformRoot, transformModule, options.options); 7491f11611SOleksandr "Alex" Zinenko if (mlirLogicalResultIsSuccess(result)) 7591f11611SOleksandr "Alex" Zinenko return; 7691f11611SOleksandr "Alex" Zinenko 77*5cd42747SPeter Hawkins throw nb::value_error( 78*5cd42747SPeter Hawkins ("Failed to apply named transform sequence.\nDiagnostic message " + 79*5cd42747SPeter Hawkins scope.takeMessage()) 80*5cd42747SPeter Hawkins .c_str()); 8191f11611SOleksandr "Alex" Zinenko }, 82*5cd42747SPeter Hawkins nb::arg("payload_root"), nb::arg("transform_root"), 83*5cd42747SPeter Hawkins nb::arg("transform_module"), 84*5cd42747SPeter Hawkins nb::arg("transform_options") = PyMlirTransformOptions()); 8573140daeSOleksandr "Alex" Zinenko 8673140daeSOleksandr "Alex" Zinenko m.def( 8773140daeSOleksandr "Alex" Zinenko "copy_symbols_and_merge_into", 8873140daeSOleksandr "Alex" Zinenko [](MlirOperation target, MlirOperation other) { 8973140daeSOleksandr "Alex" Zinenko mlir::python::CollectDiagnosticsToStringScope scope( 9073140daeSOleksandr "Alex" Zinenko mlirOperationGetContext(target)); 9173140daeSOleksandr "Alex" Zinenko 9273140daeSOleksandr "Alex" Zinenko MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); 9373140daeSOleksandr "Alex" Zinenko if (mlirLogicalResultIsFailure(result)) { 94*5cd42747SPeter Hawkins throw nb::value_error( 95*5cd42747SPeter Hawkins ("Failed to merge symbols.\nDiagnostic message " + 96*5cd42747SPeter Hawkins scope.takeMessage()) 97*5cd42747SPeter Hawkins .c_str()); 9873140daeSOleksandr "Alex" Zinenko } 9973140daeSOleksandr "Alex" Zinenko }, 100*5cd42747SPeter Hawkins nb::arg("target"), nb::arg("other")); 10191f11611SOleksandr "Alex" Zinenko } 10291f11611SOleksandr "Alex" Zinenko 103*5cd42747SPeter Hawkins NB_MODULE(_mlirTransformInterpreter, m) { 10491f11611SOleksandr "Alex" Zinenko m.doc() = "MLIR Transform dialect interpreter functionality."; 10591f11611SOleksandr "Alex" Zinenko populateTransformInterpreterSubmodule(m); 10691f11611SOleksandr "Alex" Zinenko } 107