xref: /llvm-project/mlir/python/mlir/dialects/transform/__init__.py (revision 537b2aa264c5a9879a80289c8d123b39e520eb15)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._transform_enum_gen import *
6from .._transform_ops_gen import *
7from .._transform_ops_gen import _Dialect
8from ..._mlir_libs._mlirDialectsTransform import *
9from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
10
11try:
12    from ...ir import *
13    from .._ods_common import (
14        get_op_result_or_value as _get_op_result_or_value,
15        get_op_results_or_values as _get_op_results_or_values,
16        _cext as _ods_cext,
17    )
18except ImportError as e:
19    raise RuntimeError("Error loading imports from extension module") from e
20
21from typing import Optional, Sequence, Union, NewType
22
23
@_ods_cext.register_operation(_Dialect, replace=True)
class CastOp(CastOp):
    """Convenience wrapper over the generated `CastOp` builder.

    The only difference from the generated constructor is that `target` is
    first normalized with `_get_op_result_or_value`, so it may be given as an
    operation as well as a value.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
            result_type: Type of the op result.
            target: Handle to cast, as a value or as the operation producing it.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        target_value = _get_op_result_or_value(target)
        super().__init__(result_type, target_value, loc=loc, ip=ip)
35
36
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyPatternsOp(ApplyPatternsOp):
    """Wrapper over the generated `ApplyPatternsOp` that creates its body block.

    After the generated builder runs, one block without arguments is appended
    to the first region; it is exposed through the `patterns` property.
    """

    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        loc=None,
        ip=None,
    ):
        """Build the op on `target` and append an empty block to its region.

        Args:
            target: Passed to the generated builder as-is.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        super().__init__(target, loc=loc, ip=ip)
        # The block is created here so that `patterns` is usable right after
        # construction.
        pattern_region = self.regions[0]
        pattern_region.blocks.append()

    @property
    def patterns(self) -> Block:
        """The first block of the first region."""
        pattern_region = self.regions[0]
        return pattern_region.blocks[0]
52
53
@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentOp(GetParentOp):
    """Wrapper over the generated `GetParentOp` builder.

    Normalizes `target` with `_get_op_result_or_value` and supplies defaults
    for the optional attributes.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        isolated_from_above: bool = False,
        op_name: Optional[str] = None,
        deduplicate: bool = False,
        nth_parent: int = 1,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
            result_type: Type of the op result.
            target: Handle to start from, as a value or its producing operation.
            isolated_from_above: Forwarded to the generated builder
                (default False).
            op_name: Forwarded to the generated builder (default None).
            deduplicate: Forwarded to the generated builder (default False).
            nth_parent: Forwarded to the generated builder (default 1).
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        target_value = _get_op_result_or_value(target)
        parent_options = dict(
            isolated_from_above=isolated_from_above,
            op_name=op_name,
            deduplicate=deduplicate,
            nth_parent=nth_parent,
        )
        super().__init__(
            result_type,
            target_value,
            **parent_options,
            loc=loc,
            ip=ip,
        )
78
79
@_ods_cext.register_operation(_Dialect, replace=True)
class MergeHandlesOp(MergeHandlesOp):
    """Wrapper over the generated `MergeHandlesOp` builder.

    Each entry of `handles` is normalized with `_get_op_result_or_value`
    before the list is handed to the generated constructor.
    """

    def __init__(
        self,
        handles: Sequence[Union[Operation, Value]],
        *,
        deduplicate: bool = False,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
            handles: Handles to merge, each a value or its producing operation.
            deduplicate: Forwarded to the generated builder (default False).
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        handle_values = list(map(_get_op_result_or_value, handles))
        super().__init__(
            handle_values,
            deduplicate=deduplicate,
            loc=loc,
            ip=ip,
        )
96
97
@_ods_cext.register_operation(_Dialect, replace=True)
class ReplicateOp(ReplicateOp):
    """Wrapper over the generated `ReplicateOp` builder.

    The result types are derived from the handles: one result per handle,
    with the same type as that handle.
    """

    def __init__(
        self,
        pattern: Union[Operation, Value],
        handles: Sequence[Union[Operation, Value]],
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
            pattern: Pattern handle, as a value or its producing operation.
            handles: Handles to replicate, each a value or its producing
                operation.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        # Resolve every handle exactly once. Walking `handles` twice would
        # resolve each entry twice and would produce mismatched result types
        # and operands if `handles` is a one-shot iterable.
        handle_values = [_get_op_result_or_value(h) for h in handles]
        super().__init__(
            [value.type for value in handle_values],
            _get_op_result_or_value(pattern),
            handle_values,
            loc=loc,
            ip=ip,
        )
115
116
@_ods_cext.register_operation(_Dialect, replace=True)
class SequenceOp(SequenceOp):
    """Wrapper over the generated `SequenceOp` that also creates the body block.

    The body block receives one argument for the target followed by one
    argument per extra binding.
    """

    def __init__(
        self,
        failure_propagation_mode,
        results: Sequence[Type],
        target: Union[Operation, Value, Type, OpView],
        extra_bindings: Optional[
            Union[Sequence[Value], Sequence[Type], Operation, OpView]
        ] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op and its body block.

        Args:
            failure_propagation_mode: Forwarded to the generated builder.
            results: Result types of the op.
            target: Either a `Type`, in which case the op gets no root operand
                and the type is only used for the first block argument, or a
                value / operation that becomes the root operand.
            extra_bindings: Optional additional bindings. A sequence of `Type`
                only adds block arguments; a sequence of values (or an
                operation, whose results are used) adds both operands and
                block arguments.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        if isinstance(target, Type):
            root = None
            root_type = target
        else:
            # Covers Operation, OpView and Value alike; previously an OpView
            # target left `root` as None and crashed on `root.type`.
            root = _get_op_result_or_value(target)
            root_type = root.type

        if extra_bindings is None:
            extra_bindings = []
        if isinstance(extra_bindings, (Operation, OpView)):
            extra_bindings = _get_op_results_or_values(extra_bindings)

        extra_binding_types = []
        if len(extra_bindings) != 0:
            if isinstance(extra_bindings[0], Type):
                # Types only describe block arguments; there are no operands.
                extra_binding_types = list(extra_bindings)
                extra_bindings = []
            else:
                extra_binding_types = [v.type for v in extra_bindings]

        super().__init__(
            results_=results,
            failure_propagation_mode=failure_propagation_mode,
            root=root,
            extra_bindings=extra_bindings,
            loc=loc,
            ip=ip,
        )
        # Unpack instead of concatenating lists so that any sequence kind
        # (e.g. a tuple of types) is accepted.
        self.regions[0].blocks.append(root_type, *extra_binding_types)

    @property
    def body(self) -> Block:
        """The first block of the first region."""
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """The first argument of the body block (the target)."""
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        """All body block arguments after the first one."""
        return self.body.arguments[1:]
167
168
@_ods_cext.register_operation(_Dialect, replace=True)
class NamedSequenceOp(NamedSequenceOp):
    """Wrapper over the generated `NamedSequenceOp` that creates the body block.

    The function type attribute is built from `input_types` and
    `result_types`, and the body block gets one argument per input type.
    """

    def __init__(
        self,
        sym_name,
        input_types: Sequence[Type],
        result_types: Sequence[Type],
        sym_visibility=None,
        arg_attrs=None,
        res_attrs=None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op and its body block.

        Args:
            sym_name: Forwarded to the generated builder.
            input_types: Argument types of the sequence; also the types of the
                body block arguments.
            result_types: Result types of the sequence.
            sym_visibility: Forwarded to the generated builder.
            arg_attrs: Forwarded to the generated builder.
            res_attrs: Forwarded to the generated builder.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        # `input_types` is needed twice (function type and block arguments),
        # so materialize it once in case a one-shot iterable was passed.
        input_types = list(input_types)
        function_type = FunctionType.get(input_types, result_types)
        super().__init__(
            sym_name=sym_name,
            function_type=TypeAttr.get(function_type),
            sym_visibility=sym_visibility,
            arg_attrs=arg_attrs,
            res_attrs=res_attrs,
            loc=loc,
            ip=ip,
        )
        self.regions[0].blocks.append(*input_types)

    @property
    def body(self) -> Block:
        """The first block of the first region."""
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """The first argument of the body block."""
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        """All body block arguments after the first one."""
        return self.body.arguments[1:]
201
202
@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
    """Wrapper over the generated `YieldOp` builder with optional operands."""

    def __init__(
        self,
        operands: Optional[Union[Operation, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
            operands: Values to yield, or an operation; `None` is treated as
                an empty list. The argument is normalized with
                `_get_op_results_or_values`.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        yielded = [] if operands is None else operands
        super().__init__(_get_op_results_or_values(yielded), loc=loc, ip=ip)
215
216
# Distinct static type for values produced by `any_op_t()`. The name passed to
# `NewType` must match the variable it is assigned to; type checkers reject a
# mismatch and the name is what the object reports as its `__name__`.
AnyOpTypeT = NewType("AnyOpTypeT", AnyOpType)
218
219
def any_op_t() -> AnyOpTypeT:
    """Return the result of `AnyOpType.get()`, typed as `AnyOpTypeT`."""
    any_op_type = AnyOpType.get()
    return AnyOpTypeT(any_op_type)
222