xref: /llvm-project/mlir/python/mlir/dialects/pdl.py (revision 18cf1cd92b554ba0b870c6a2223ea4d0d3c6dd21)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from ._pdl_ops_gen import *
6from ._pdl_ops_gen import _Dialect
7from .._mlir_libs._mlirDialectsPDL import *
8from .._mlir_libs._mlirDialectsPDL import OperationType
9from ..extras.meta import region_op
10
11try:
12    from ..ir import *
13    from ..dialects import pdl
14except ImportError as e:
15    raise RuntimeError("Error loading imports from extension module") from e
16
17from typing import Union, Optional, Sequence, Mapping, NewType
18from ._ods_common import (
19    get_op_result_or_value as _get_value,
20    get_op_results_or_values as _get_values,
21    _cext as _ods_cext,
22)
23
24
@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
    """Specialization of the generated PDL ``AttributeOp`` builder.

    The result type is always ``pdl.AttributeType.get()``. Callers only pass
    the optional type handle and/or constant attribute value.
    """

    def __init__(
        self,
        valueType: Optional[Union[OpView, Operation, Value]] = None,
        value: Optional[Attribute] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          valueType: optional handle; when given it is resolved with
            ``_get_value`` before being forwarded as ``valueType``.
          value: optional attribute, forwarded unchanged as ``value``.
          loc, ip: forwarded unchanged to the generated builder.
        """
        # Only resolve real handles; None means "no type constraint" here.
        if valueType is not None:
            valueType = _get_value(valueType)
        super().__init__(
            pdl.AttributeType.get(),
            valueType=valueType,
            value=value,
            loc=loc,
            ip=ip,
        )
40
41
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
    """Specialization of the generated PDL ``OperandOp`` builder.

    The result type is fixed to ``pdl.ValueType.get()``.
    """

    def __init__(
        self,
        type: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          type: optional handle, resolved with ``_get_value`` and forwarded
            as ``valueType``. The name shadows the builtin but is part of the
            public keyword interface, so it is kept.
          loc, ip: forwarded unchanged to the generated builder.
        """
        value_type = None if type is None else _get_value(type)
        super().__init__(pdl.ValueType.get(), valueType=value_type, loc=loc, ip=ip)
56
57
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandsOp(OperandsOp):
    """Specialization of the generated PDL ``OperandsOp`` builder.

    The result type is fixed to ``pdl.RangeType.get(pdl.ValueType.get())``.
    """

    def __init__(
        self,
        types: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          types: optional single handle, resolved with ``_get_value`` and
            forwarded as ``valueType``.
          loc, ip: forwarded unchanged to the generated builder.
        """
        value_types = None if types is None else _get_value(types)
        range_of_values = pdl.RangeType.get(pdl.ValueType.get())
        super().__init__(range_of_values, valueType=value_types, loc=loc, ip=ip)
72
73
@_ods_cext.register_operation(_Dialect, replace=True)
class OperationOp(OperationOp):
    """Specialization for PDL operation op class.

    Accepts Python-friendly arguments (an optional name, operand handles, a
    ``{attribute name: handle}`` mapping and result-type handles). The result
    type is fixed to ``pdl.OperationType.get()``.
    """

    def __init__(
        self,
        name: Optional[Union[str, StringAttr]] = None,
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
        types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          name: optional operation name, forwarded as ``opName``.
          args: operand handles resolved with ``_get_values``. Defaults to none.
          attributes: mapping from attribute name to a handle for its value.
            The mapping's iteration order fixes the order of the parallel
            name/value lists passed to the generated builder. Defaults to none.
          types: result-type handles resolved with ``_get_values``. Defaults
            to none.
          loc, ip: forwarded unchanged to the generated builder.
        """
        operands = _get_values([] if args is None else args)
        if attributes is None:
            attributes = {}
        # The generated builder takes attribute names and values as two
        # parallel lists: names as an ArrayAttr of StringAttrs, values as
        # resolved Values.
        attr_names = []
        attr_values = []
        for attr_name, attr_value in attributes.items():
            attr_names.append(StringAttr.get(attr_name))
            attr_values.append(_get_value(attr_value))
        attr_names_attr = ArrayAttr.get(attr_names)
        result_types = _get_values([] if types is None else types)
        super().__init__(
            pdl.OperationType.get(),
            operands,
            attr_values,
            attr_names_attr,
            result_types,
            opName=name,
            loc=loc,
            ip=ip,
        )
106
107
@_ods_cext.register_operation(_Dialect, replace=True)
class PatternOp(PatternOp):
    """Specialization of the generated PDL ``PatternOp`` builder.

    Construction also appends an empty entry block to the first region.
    """

    def __init__(
        self,
        benefit: Union[IntegerAttr, int],
        name: Optional[Union[StringAttr, str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `pattern` operation with an empty body block.

        Args:
          benefit: forwarded unchanged to the generated builder.
          name: optional symbol name, forwarded as ``sym_name``.
          loc, ip: forwarded unchanged to the generated builder.
        """
        super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
        # Append a block without arguments so ``body`` works immediately.
        first_region = self.regions[0]
        first_region.blocks.append()

    @property
    def body(self):
        """The first block of the pattern's first region."""
        first_region = self.regions[0]
        return first_region.blocks[0]


# NOTE(review): this wraps the generated base class rather than the
# specialization above. Presumably that is because the specialized
# ``__init__`` already appends a body block while ``region_op`` creates its
# own. Confirm against ``extras.meta.region_op``.
pattern = region_op(PatternOp.__base__)
131
132
@_ods_cext.register_operation(_Dialect, replace=True)
class ReplaceOp(ReplaceOp):
    """Specialization of the generated PDL ``ReplaceOp`` builder."""

    def __init__(
        self,
        op: Union[OpView, Operation, Value],
        *,
        with_op: Optional[Union[OpView, Operation, Value]] = None,
        with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          op: handle of the operation being replaced, resolved with
            ``_get_value``.
          with_op: optional replacement-operation handle, forwarded as
            ``replOperation``.
          with_values: replacement value handles resolved with
            ``_get_values``. Defaults to none.
          loc, ip: forwarded unchanged to the generated builder.
        """
        replaced = _get_value(op)
        replacement_op = None if with_op is None else _get_value(with_op)
        replacement_values = _get_values(
            with_values if with_values is not None else []
        )
        super().__init__(
            replaced,
            replacement_values,
            replOperation=replacement_op,
            loc=loc,
            ip=ip,
        )
152
153
@_ods_cext.register_operation(_Dialect, replace=True)
class ResultOp(ResultOp):
    """Specialization of the generated PDL ``ResultOp`` builder.

    The result type is fixed to ``pdl.ValueType.get()``.
    """

    def __init__(
        self,
        parent: Union[OpView, Operation, Value],
        index: Union[IntegerAttr, int],
        *,
        loc=None,
        ip=None,
    ):
        """Build the op from a parent handle and a result index.

        Args:
          parent: handle resolved with ``_get_value`` and passed positionally.
          index: forwarded unchanged to the generated builder.
          loc, ip: forwarded unchanged to the generated builder.
        """
        parent_value = _get_value(parent)
        super().__init__(pdl.ValueType.get(), parent_value, index, loc=loc, ip=ip)
169
170
@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
    """Specialization of the generated PDL ``RewriteOp`` builder.

    Unlike ``PatternOp``, construction does not create a body block. Call
    ``add_body`` for that.
    """

    def __init__(
        self,
        root: Optional[Union[OpView, Operation, Value]] = None,
        name: Optional[Union[StringAttr, str]] = None,
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          root: optional root handle, resolved with ``_get_value`` and
            forwarded as ``root``.
          name: optional name, forwarded unchanged as ``name``.
          args: argument handles resolved with ``_get_values``. Defaults to
            none.
          loc, ip: forwarded unchanged to the generated builder.
        """
        root_value = None if root is None else _get_value(root)
        arg_values = _get_values(args if args is not None else [])
        super().__init__(arg_values, root=root_value, name=name, loc=loc, ip=ip)

    def add_body(self):
        """Append an empty block to the first region and return ``body``.

        ``body`` is always the first block. Calling this more than once
        appends further blocks but still returns the first one.
        """
        first_region = self.regions[0]
        first_region.blocks.append()
        return self.body

    @property
    def body(self):
        """The first block of the rewrite's first region."""
        first_region = self.regions[0]
        return first_region.blocks[0]


# Wraps the specialization directly; its ``__init__`` adds no block.
# NOTE(review): presumably ``region_op`` supplies the body block. Confirm
# against ``extras.meta.region_op``.
rewrite = region_op(RewriteOp)
202
203
@_ods_cext.register_operation(_Dialect, replace=True)
class TypeOp(TypeOp):
    """Specialization of the generated PDL ``TypeOp`` builder.

    The result type is fixed to ``pdl.TypeType.get()``.
    """

    def __init__(
        self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
    ):
        """Build the op.

        Args:
          constantType: optional constant type, forwarded unchanged.
          loc, ip: forwarded unchanged to the generated builder.
        """
        super().__init__(
            pdl.TypeType.get(), constantType=constantType, loc=loc, ip=ip
        )
213
214
@_ods_cext.register_operation(_Dialect, replace=True)
class TypesOp(TypesOp):
    """Specialization of the generated PDL ``TypesOp`` builder.

    The result type is fixed to ``pdl.RangeType.get(pdl.TypeType.get())``.
    """

    def __init__(
        self,
        constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op.

        Args:
          constantTypes: constant types forwarded as ``constantTypes``.
            ``None`` is replaced by an empty list.
          loc, ip: forwarded unchanged to the generated builder.
        """
        type_list = [] if constantTypes is None else constantTypes
        range_of_types = pdl.RangeType.get(pdl.TypeType.get())
        super().__init__(range_of_types, constantTypes=type_list, loc=loc, ip=ip)
230
231
# Distinct static type for values produced by ``op_t``.
# NOTE(review): the NewType name string ("OperationType") differs from the
# variable name. Type checkers such as mypy expect these to match. It is kept
# unchanged here for compatibility.
OperationTypeT = NewType("OperationType", OperationType)


def op_t() -> OperationTypeT:
    """Return ``OperationType.get()`` wrapped as ``OperationTypeT``."""
    operation_type = OperationType.get()
    return OperationTypeT(operation_type)
237