# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._transform_pdl_extension_ops_gen import *
from .._transform_pdl_extension_ops_gen import _Dialect

try:
    from ...ir import *
    from .._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Union


@_ods_cext.register_operation(_Dialect, replace=True)
class PDLMatchOp(PDLMatchOp):
    """Convenience builder wrapping the generated ``PDLMatchOp``.

    Registered through ``register_operation(_Dialect, replace=True)``,
    presumably so that it supersedes the generated class of the same name.
    Its only addition is letting ``target`` be either an ``Operation`` or a
    ``Value``.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        pattern_name: Union[Attribute, str],
        *,
        loc=None,
        ip=None,
    ):
        """Build the op by delegating to the generated constructor.

        Args:
            result_type: Result type, forwarded unchanged.
            target: Operation or Value to match against; normalized with
                ``_get_op_result_or_value`` before being forwarded.
            pattern_name: Pattern name given as an ``Attribute`` or a
                ``str``; forwarded unchanged. The generated constructor is
                assumed to convert the ``str`` case (not visible here).
            loc: Forwarded to the generated constructor.
            ip: Forwarded to the generated constructor.
        """
        target_value = _get_op_result_or_value(target)
        # The generated constructor's parameter names are not visible from
        # here, so the three leading arguments stay positional, in order.
        super().__init__(result_type, target_value, pattern_name, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class WithPDLPatternsOp(WithPDLPatternsOp):
    """Convenience builder wrapping the generated ``WithPDLPatternsOp``.

    The root may be supplied as an ``Operation``/``Value`` (used as the
    ``root`` operand) or as a bare ``Type`` (no ``root`` operand). In both
    cases a block is appended to the op's first region, built from the root
    type.
    """

    def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
        """Build the op and append its body block.

        Args:
            target: A ``Type`` means no root operand: the op is built with
                ``root=None`` and the type is used for the body block.
                Anything else is normalized with ``_get_op_result_or_value``
                and serves as the ``root`` operand. Its ``.type`` is used for
                the body block.
            loc: Forwarded to the generated constructor.
            ip: Forwarded to the generated constructor.
        """
        if isinstance(target, Type):
            root = None
            root_type = target
        else:
            root = _get_op_result_or_value(target)
            root_type = root.type
        super().__init__(root=root, loc=loc, ip=ip)
        # Append the body block to region 0, passing the root type. The
        # ``bodyTarget`` property below reads ``arguments[0]`` of this
        # block, i.e. the block is expected to carry one argument of that
        # type.
        self.regions[0].blocks.append(root_type)

    @property
    def body(self) -> Block:
        """Return the first block of region 0.

        When built through this class, this is the block appended in
        ``__init__``.
        """
        first_region = self.regions[0]
        return first_region.blocks[0]

    @property
    def bodyTarget(self) -> Value:
        """Return the first argument of ``body``."""
        body_block = self.body
        return body_block.arguments[0]