#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._transform_enum_gen import *
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType

try:
    from ...ir import *
    from .._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, NewType


@_ods_cext.register_operation(_Dialect, replace=True)
class CastOp(CastOp):
    """Builder override for the generated `CastOp`.

    `target` is normalized with `_get_op_result_or_value`, so callers may pass
    an operation in place of a Value (presumably a single-result op).
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        loc=None,
        ip=None,
    ):
        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyPatternsOp(ApplyPatternsOp):
    """Builder override for the generated `ApplyPatternsOp`.

    After construction, an empty block is appended to region 0; it is exposed
    through the `patterns` property so callers can insert pattern ops into it.
    """

    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        loc=None,
        ip=None,
    ):
        # Normalize like the sibling builders so Operation/OpView/Value all work.
        super().__init__(_get_op_result_or_value(target), loc=loc, ip=ip)
        self.regions[0].blocks.append()

    @property
    def patterns(self) -> Block:
        """The block created in `__init__` (first block of region 0)."""
        return self.regions[0].blocks[0]


@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentOp(GetParentOp):
    """Builder override for the generated `GetParentOp`.

    Normalizes `target` and forwards every option unchanged to the generated
    builder.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        isolated_from_above: bool = False,
        op_name: Optional[str] = None,
        deduplicate: bool = False,
        nth_parent: int = 1,
        loc=None,
        ip=None,
    ):
        super().__init__(
            result_type,
            _get_op_result_or_value(target),
            isolated_from_above=isolated_from_above,
            op_name=op_name,
            deduplicate=deduplicate,
            nth_parent=nth_parent,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MergeHandlesOp(MergeHandlesOp):
    """Builder override for the generated `MergeHandlesOp`.

    Each entry of `handles` is normalized with `_get_op_result_or_value`.
    """

    def __init__(
        self,
        handles: Sequence[Union[Operation, Value]],
        *,
        deduplicate: bool = False,
        loc=None,
        ip=None,
    ):
        super().__init__(
            [_get_op_result_or_value(h) for h in handles],
            deduplicate=deduplicate,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class ReplicateOp(ReplicateOp):
    """Builder override for the generated `ReplicateOp`.

    One result is created per handle, with the result types copied from the
    (normalized) handles.
    """

    def __init__(
        self,
        pattern: Union[Operation, Value],
        handles: Sequence[Union[Operation, Value]],
        *,
        loc=None,
        ip=None,
    ):
        # Normalize each handle once; reuse for both result types and operands.
        handle_values = [_get_op_result_or_value(h) for h in handles]
        super().__init__(
            [v.type for v in handle_values],
            _get_op_result_or_value(pattern),
            handle_values,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class SequenceOp(SequenceOp):
    """Builder override for the generated `SequenceOp`.

    Creates the op and a body block in region 0. The body block arguments are
    the root type followed by the extra-binding types.

    Args:
      failure_propagation_mode: forwarded unchanged to the generated builder.
      results: result types of the op.
      target: either a handle (Operation/OpView/Value), which becomes the
        `root` operand and whose type is used for the first block argument, or
        a `Type`, in which case no `root` operand is set and the type is used
        for the first block argument only.
      extra_bindings: either handles (a sequence of Values, or an
        Operation/OpView whose results are used), which become operands and
        whose types are used for the remaining block arguments, or a sequence
        of `Type`s, which only determine the remaining block argument types.
      loc, ip: forwarded to the generated builder.
    """

    def __init__(
        self,
        failure_propagation_mode,
        results: Sequence[Type],
        target: Union[Operation, Value, OpView, Type],
        extra_bindings: Optional[
            Union[Sequence[Value], Sequence[Type], Operation, OpView]
        ] = None,
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(target, Type):
            # Only the block argument type is known; no root operand.
            root = None
            root_type = target
        else:
            # Handles Operation, OpView and Value alike (OpView used to crash
            # here because `root` stayed None).
            root = _get_op_result_or_value(target)
            root_type = root.type

        if extra_bindings is None:
            extra_bindings = []
        if isinstance(extra_bindings, (Operation, OpView)):
            extra_bindings = _get_op_results_or_values(extra_bindings)

        # Materialize as a list: tuples of Types previously failed when
        # concatenated with a list below.
        extra_bindings = list(extra_bindings)
        if extra_bindings and isinstance(extra_bindings[0], Type):
            extra_binding_types = extra_bindings
            extra_bindings = []
        else:
            extra_binding_types = [v.type for v in extra_bindings]

        super().__init__(
            results_=results,
            failure_propagation_mode=failure_propagation_mode,
            root=root,
            extra_bindings=extra_bindings,
            loc=loc,
            ip=ip,
        )
        self.regions[0].blocks.append(root_type, *extra_binding_types)

    @property
    def body(self) -> Block:
        """The body block created in `__init__`."""
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """The first body block argument (typed with the root type)."""
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        """The body block arguments after the first one."""
        return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class NamedSequenceOp(NamedSequenceOp):
    """Builder override for the generated `NamedSequenceOp`.

    Builds a `FunctionType` from `input_types` and `result_types`, and appends
    a body block with one argument per input type.
    """

    def __init__(
        self,
        sym_name,
        input_types: Sequence[Type],
        result_types: Sequence[Type],
        sym_visibility=None,
        arg_attrs=None,
        res_attrs=None,
        *,
        loc=None,
        ip=None,
    ):
        function_type = FunctionType.get(input_types, result_types)
        super().__init__(
            sym_name=sym_name,
            function_type=TypeAttr.get(function_type),
            sym_visibility=sym_visibility,
            arg_attrs=arg_attrs,
            res_attrs=res_attrs,
            loc=loc,
            ip=ip,
        )
        self.regions[0].blocks.append(*input_types)

    @property
    def body(self) -> Block:
        """The body block created in `__init__`."""
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """The first body block argument."""
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        """The body block arguments after the first one."""
        return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
    """Builder override for the generated `YieldOp`.

    `operands` may be omitted (no operands), an operation (its results are
    used), or a sequence of Values.
    """

    def __init__(
        self,
        operands: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if operands is None:
            operands = []
        super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)


# Static-typing tag for values produced by `any_op_t`. The NewType name must
# match the variable it is bound to (PEP 484).
AnyOpTypeT = NewType("AnyOpTypeT", AnyOpType)


def any_op_t() -> AnyOpTypeT:
    """Return `AnyOpType.get()` tagged as `AnyOpTypeT` for static typing."""
    return AnyOpTypeT(AnyOpType.get())