# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._memref_transform_ops_gen import *
from .._memref_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, overload, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
    """Specialization for MemRefAllocaToGlobalOp class.

    Adds a convenience constructor on top of the generated builder. The op
    can be built in two ways:

    * ``MemRefAllocaToGlobalOp(get_global_type, global_type, alloca)`` --
      both result types are given explicitly.
    * ``MemRefAllocaToGlobalOp(alloca)`` -- both result types default to
      ``transform.AnyOpType``.
    """

    @overload
    def __init__(
        self,
        get_global_type: Type,
        global_type: Type,
        alloca: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
        ...

    def __init__(
        self,
        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
        global_type_or_none: Optional[Type] = None,
        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op from either the explicit or the short argument form.

        Args:
            get_global_type_or_alloca: Either the first result type
                (explicit form) or the ``alloca`` handle (short form).
            global_type_or_none: Second result type; only read in the
                explicit form.
            alloca_or_none: The ``alloca`` handle; only read in the explicit
                form.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        # The two call forms are told apart solely by whether the first
        # positional argument is a `Type`.
        if isinstance(get_global_type_or_alloca, Type):
            get_global_type = get_global_type_or_alloca
            global_type = global_type_or_none
            alloca = alloca_or_none
        else:
            # Short form: no result types were supplied, so fall back to the
            # generic transform op handle type for both results. The second
            # and third positional parameters are not consulted here.
            get_global_type = transform.AnyOpType.get()
            global_type = transform.AnyOpType.get()
            alloca = get_global_type_or_alloca

        super().__init__(
            get_global_type,
            global_type,
            alloca,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefMultiBufferOp(MemRefMultiBufferOp):
    """Specialization for MemRefMultiBufferOp class.

    Adds a convenience constructor on top of the generated builder. The op
    can be built in two ways:

    * ``MemRefMultiBufferOp(transformed_type, target, factor)`` -- the
      result type is given explicitly.
    * ``MemRefMultiBufferOp(target, factor)`` -- the result type defaults to
      ``transform.AnyOpType``.
    """

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        # In the short form this slot carries the target, not a type, so the
        # annotation has to admit both.
        transformed_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_factor: Optional[
            Union[int, IntegerAttr, Operation, OpView, Value]
        ] = None,
        factor_or_none: Optional[Union[int, IntegerAttr]] = None,
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        """Build the op from either the explicit or the short argument form.

        Args:
            transformed_type_or_target: Either the result type (explicit
                form) or the ``target`` handle (short form).
            target_or_factor: The ``target`` handle in the explicit form, or
                the ``factor`` in the short form.
            factor_or_none: The ``factor``; only read in the explicit form.
            skip_analysis: Forwarded unchanged to the generated builder.
            loc: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
        """
        # The two call forms are told apart solely by whether the first
        # positional argument is a `Type`.
        if isinstance(transformed_type_or_target, Type):
            transformed_type = transformed_type_or_target
            target = target_or_factor
            factor = factor_or_none
        else:
            # Short form: every positional argument shifts one slot to the
            # left and the result type falls back to the generic op handle.
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target
            factor = target_or_factor

        super().__init__(
            transformed_type,
            target,
            factor,
            skip_analysis=skip_analysis,
            loc=loc,
            ip=ip,
        )