# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._pdl_ops_gen import *
from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
from .._mlir_libs._mlirDialectsPDL import OperationType
from ..extras.meta import region_op

try:
    from ..ir import *
    from ..dialects import pdl
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Union, Optional, Sequence, Mapping, NewType
from ._ods_common import (
    get_op_result_or_value as _get_value,
    get_op_results_or_values as _get_values,
    _cext as _ods_cext,
)


@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
    """Specialization for PDL attribute op class."""

    def __init__(
        self,
        valueType: Optional[Union[OpView, Operation, Value]] = None,
        value: Optional[Attribute] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `attribute` operation typed by `pdl.AttributeType`.

        Args:
          valueType: optional op/value forwarded as the `valueType` operand;
            converted with `get_op_result_or_value` when given.
          value: optional constant attribute forwarded as `value`.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        if valueType is not None:
            valueType = _get_value(valueType)
        result = pdl.AttributeType.get()
        super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
    """Specialization for PDL operand op class."""

    def __init__(
        self,
        type: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `operand` operation typed by `pdl.ValueType`.

        Args:
          type: optional op/value forwarded as the `valueType` operand;
            converted with `get_op_result_or_value` when given.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        if type is not None:
            type = _get_value(type)
        result = pdl.ValueType.get()
        super().__init__(result, valueType=type, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class OperandsOp(OperandsOp):
    """Specialization for PDL operands op class."""

    def __init__(
        self,
        types: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `operands` operation typed as a range of `pdl.ValueType`.

        Args:
          types: optional op/value forwarded as the `valueType` operand;
            converted with `get_op_result_or_value` when given.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        if types is not None:
            types = _get_value(types)
        result = pdl.RangeType.get(pdl.ValueType.get())
        super().__init__(result, valueType=types, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class OperationOp(OperationOp):
    """Specialization for PDL operation op class."""

    def __init__(
        self,
        name: Optional[Union[str, StringAttr]] = None,
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
        types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `operation` operation typed by `pdl.OperationType`.

        Args:
          name: optional operation name, forwarded as `opName`.
          args: operand ops/values (default: none), converted with
            `get_op_results_or_values`.
          attributes: mapping from attribute name to the op/value providing
            it (default: empty). Names are wrapped in `StringAttr`s and
            collected into a single `ArrayAttr`; values are converted to
            `Value`s in the same iteration order, so names and values stay
            paired positionally.
          types: result-type ops/values (default: none), converted with
            `get_op_results_or_values`.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        if types is None:
            types = []
        if attributes is None:
            attributes = {}
        if args is None:
            args = []
        args = _get_values(args)
        attrNames = []
        attrValues = []
        # Iterate items() once so each name is paired with its own value.
        for attrName, attrValue in attributes.items():
            attrNames.append(StringAttr.get(attrName))
            attrValues.append(_get_value(attrValue))
        attrNames = ArrayAttr.get(attrNames)
        types = _get_values(types)
        result = pdl.OperationType.get()
        super().__init__(
            result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class PatternOp(PatternOp):
    """Specialization for PDL pattern op class."""

    def __init__(
        self,
        benefit: Union[IntegerAttr, int],
        name: Optional[Union[StringAttr, str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `pattern` operation and appends an empty body block.

        Args:
          benefit: pattern benefit, forwarded to the generated builder.
          name: optional symbol name, forwarded as `sym_name`.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
        self.regions[0].blocks.append()

    @property
    def body(self):
        """Return the body (block) of the pattern."""
        return self.regions[0].blocks[0]


# NOTE: built from the generated base class rather than the specialization
# above, presumably so the body block is not created twice: the specialized
# `PatternOp.__init__` already appends one. `rewrite` below wraps the
# specialization directly because `RewriteOp.__init__` appends no block.
pattern = region_op(PatternOp.__base__)


@_ods_cext.register_operation(_Dialect, replace=True)
class ReplaceOp(ReplaceOp):
    """Specialization for PDL replace op class."""

    def __init__(
        self,
        op: Union[OpView, Operation, Value],
        *,
        with_op: Optional[Union[OpView, Operation, Value]] = None,
        with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `replace` operation.

        Args:
          op: the operation being replaced, converted to a `Value`.
          with_op: optional replacement operation, forwarded as
            `replOperation`.
          with_values: replacement values (default: none), converted with
            `get_op_results_or_values`.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        if with_values is None:
            with_values = []
        op = _get_value(op)
        if with_op is not None:
            with_op = _get_value(with_op)
        with_values = _get_values(with_values)
        super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class ResultOp(ResultOp):
    """Specialization for PDL result op class."""

    def __init__(
        self,
        parent: Union[OpView, Operation, Value],
        index: Union[IntegerAttr, int],
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `result` operation typed by `pdl.ValueType`.

        Args:
          parent: operation whose result is referenced, converted to a `Value`.
          index: result index, forwarded to the generated builder.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        parent = _get_value(parent)
        result = pdl.ValueType.get()
        super().__init__(result, parent, index, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
    """Specialization for PDL rewrite op class."""

    def __init__(
        self,
        root: Optional[Union[OpView, Operation, Value]] = None,
        name: Optional[Union[StringAttr, str]] = None,
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `rewrite` operation (without a body block).

        Args:
          root: optional root op/value, forwarded as `root`.
          name: optional name, forwarded as `name`.
          args: ops/values (default: none) converted with
            `get_op_results_or_values` and passed positionally.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        if args is None:
            args = []
        if root is not None:
            root = _get_value(root)
        args = _get_values(args)
        super().__init__(args, root=root, name=name, loc=loc, ip=ip)

    def add_body(self):
        """Append a block to the rewrite region and return the first block."""
        self.regions[0].blocks.append()
        return self.body

    @property
    def body(self):
        """Return the body (block) of the rewrite."""
        return self.regions[0].blocks[0]


rewrite = region_op(RewriteOp)


@_ods_cext.register_operation(_Dialect, replace=True)
class TypeOp(TypeOp):
    """Specialization for PDL type op class."""

    def __init__(
        self,
        constantType: Optional[Union[TypeAttr, Type]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `type` operation typed by `pdl.TypeType`.

        Args:
          constantType: optional constant type, forwarded as `constantType`
            (left unset when None).
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        result = pdl.TypeType.get()
        super().__init__(result, constantType=constantType, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class TypesOp(TypesOp):
    """Specialization for PDL types op class."""

    def __init__(
        self,
        constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `types` operation typed as a range of `pdl.TypeType`.

        Args:
          constantTypes: optional constant types, forwarded as
            `constantTypes`. When None the attribute is left unset, matching
            how `TypeOp` forwards `constantType=None`. Pass an explicit
            (possibly empty) sequence to pin the range to exactly those types.
          loc: optional location of the new operation.
          ip: optional insertion point of the new operation.
        """
        # Do not coerce None to []: an explicit empty `constantTypes` list
        # constrains the range to zero types in PDL, whereas an absent
        # attribute leaves the range unconstrained.
        result = pdl.RangeType.get(pdl.TypeType.get())
        super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)


# The NewType name must match the variable it is bound to (type checkers
# such as mypy reject a mismatch).
OperationTypeT = NewType("OperationTypeT", OperationType)


def op_t() -> OperationTypeT:
    """Return `OperationType.get()` wrapped as an `OperationTypeT`."""
    return OperationTypeT(OperationType.get())