xref: /llvm-project/mlir/python/mlir/dialects/transform/interpreter/__init__.py (revision ff57f40673f0db2c1a867e5697d5407bc9f39a5e)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from ....ir import Operation
6from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
7
# Re-export `TransformOptions` from the native `_mlirTransformInterpreter`
# extension module so that it is reachable directly from this package.
TransformOptions = _cextTransformInterpreter.TransformOptions
9
10
def _unpack_operation(op):
    """Returns `op` unchanged when it is an `Operation` instance; otherwise
    returns its `.operation` attribute."""
    return op if isinstance(op, Operation) else op.operation
15
16
def apply_named_sequence(
    payload_root, transform_root, transform_module, transform_options=None
):
    """Runs the transformation script rooted at `transform_root` on the
    payload operation `payload_root`.

    `transform_module` is the module containing the transform root. The
    transform root operation must implement TransformOpInterface and the
    module must be a ModuleOp. Each of the three operation arguments may be
    either an `Operation` or an object exposing one through `.operation`.

    `transform_options` is forwarded to the native interpreter only when it
    is not None; otherwise the native entry point is called without it."""

    # Unwrap in the same left-to-right order the arguments are declared.
    call_args = [
        _unpack_operation(candidate)
        for candidate in (payload_root, transform_root, transform_module)
    ]
    # Pass the options positionally only when the caller supplied them.
    if transform_options is not None:
        call_args.append(transform_options)
    _cextTransformInterpreter.apply_named_sequence(*call_args)
33
34
def copy_symbols_and_merge_into(target, other):
    """Copies the symbols of `other` into `target`. Private symbols are
    renamed to avoid duplicates; an error is raised if the copy would produce
    duplicate public symbols.

    Both arguments may be either an `Operation` or an object exposing one
    through `.operation`."""
    merge_symbols = _cextTransformInterpreter.copy_symbols_and_merge_into
    target_op = _unpack_operation(target)
    other_op = _unpack_operation(other)
    merge_symbols(target_op, other_op)
42