xref: /llvm-project/mlir/python/mlir/dialects/transform/pdl.py (revision a2288a8944c310fcad1196302f16513797e1fcbc)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._transform_pdl_extension_ops_gen import *
6from .._transform_pdl_extension_ops_gen import _Dialect
7
8try:
9    from ...ir import *
10    from .._ods_common import (
11        get_op_result_or_value as _get_op_result_or_value,
12        get_op_results_or_values as _get_op_results_or_values,
13        _cext as _ods_cext,
14    )
15except ImportError as e:
16    raise RuntimeError("Error loading imports from extension module") from e
17
18from typing import Union
19
20
@_ods_cext.register_operation(_Dialect, replace=True)
class PDLMatchOp(PDLMatchOp):
    """Convenience subclass of the generated ``PDLMatchOp`` builder.

    Registered through ``_ods_cext.register_operation(..., replace=True)``,
    presumably so that this subclass stands in for the generated class of the
    dialect. Its only addition is accepting ``target`` as either an
    ``Operation`` or a ``Value``.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        pattern_name: Union[Attribute, str],
        *,
        loc=None,
        ip=None,
    ):
        """Create the op and forward all arguments to the generated builder.

        Args:
          result_type: result type of the op, passed through unchanged.
          target: the operand, given as an operation or a value. It is
            normalized with ``_get_op_result_or_value`` before being passed on.
          pattern_name: the pattern name as an ``Attribute`` or a ``str``,
            passed through unchanged. Any ``str``-to-attribute conversion is
            assumed to happen in the generated builder.
          loc: optional location, forwarded as a keyword.
          ip: optional insertion point, forwarded as a keyword.
        """
        # The generated builder takes a Value operand, so an Operation is
        # reduced to its result first.
        operand = _get_op_result_or_value(target)
        super().__init__(result_type, operand, pattern_name, loc=loc, ip=ip)
39
40
@_ods_cext.register_operation(_Dialect, replace=True)
class WithPDLPatternsOp(WithPDLPatternsOp):
    """Convenience subclass of the generated ``WithPDLPatternsOp`` builder.

    Registered through ``_ods_cext.register_operation(..., replace=True)``,
    presumably so that this subclass stands in for the generated class of the
    dialect. The constructor also creates the entry block of the op's first
    region. The ``body`` and ``bodyTarget`` properties give access to that
    block and its first argument.
    """

    def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
        """Create the op and its entry block.

        Args:
          target: either an operation/value to use as the ``root`` operand,
            or a bare ``Type``. If it is a ``Type``, ``root`` is left as
            ``None`` and the type alone describes the block argument.
            Otherwise the block argument takes the type of the normalized
            root value.
          loc: optional location, forwarded as a keyword.
          ip: optional insertion point, forwarded as a keyword.
        """
        if isinstance(target, Type):
            # Only a type was supplied: no root operand is passed to the op.
            root = None
            root_type = target
        else:
            root = _get_op_result_or_value(target)
            root_type = root.type
        super().__init__(root=root, loc=loc, ip=ip)
        # Append a new block to region 0, passing ``root_type`` as its argument
        # type. ``body`` and ``bodyTarget`` below rely on this block existing.
        self.regions[0].blocks.append(root_type)

    @property
    def body(self) -> Block:
        """First block of the op's first region (the one appended in ``__init__``)."""
        first_region = self.regions[0]
        return first_region.blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """First argument of ``body``."""
        entry = self.body
        return entry.arguments[0]
56