xref: /llvm-project/mlir/python/mlir/dialects/transform/memref.py (revision a2288a8944c310fcad1196302f16513797e1fcbc)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._memref_transform_ops_gen import *
6from .._memref_transform_ops_gen import _Dialect
7
8try:
9    from ...ir import *
10    from ...dialects import transform
11    from .._ods_common import _cext as _ods_cext
12except ImportError as e:
13    raise RuntimeError("Error loading imports from extension module") from e
14
15from typing import Optional, overload, Union
16
17
@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
    """Specialization for MemRefAllocaToGlobalOp class.

    Adds a convenience constructor with two call forms:

    * ``MemRefAllocaToGlobalOp(get_global_type, global_type, alloca)`` --
      both result handle types are given explicitly.
    * ``MemRefAllocaToGlobalOp(alloca)`` -- both result handle types default
      to ``transform.AnyOpType``.
    """

    @overload
    def __init__(
        self,
        get_global_type: Type,
        global_type: Type,
        alloca: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
        ...

    def __init__(
        self,
        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
        global_type_or_none: Optional[Type] = None,
        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op, dispatching on the type of the first argument.

        Args:
            get_global_type_or_alloca: Either the type of the first result
                (explicit form) or the ``alloca`` handle (short form).
            global_type_or_none: The type of the second result; required in
                the explicit form and must be omitted in the short form.
            alloca_or_none: The ``alloca`` handle; required in the explicit
                form and must be omitted in the short form.
            loc: Optional location forwarded to the generated builder.
            ip: Optional insertion point forwarded to the generated builder.

        Raises:
            TypeError: If the positional arguments do not match either form.
        """
        if isinstance(get_global_type_or_alloca, Type):
            # Explicit form: (get_global_type, global_type, alloca).
            if global_type_or_none is None or alloca_or_none is None:
                raise TypeError(
                    "MemRefAllocaToGlobalOp: when result types are given, both "
                    "'global_type' and 'alloca' must also be provided"
                )
            get_global_type = get_global_type_or_alloca
            global_type = global_type_or_none
            alloca = alloca_or_none
        else:
            # Short form: (alloca). Reject extra positionals rather than
            # silently dropping what the caller passed.
            if global_type_or_none is not None or alloca_or_none is not None:
                raise TypeError(
                    "MemRefAllocaToGlobalOp: unexpected extra positional "
                    "arguments; pass either (alloca) or "
                    "(get_global_type, global_type, alloca)"
                )
            get_global_type = transform.AnyOpType.get()
            global_type = transform.AnyOpType.get()
            alloca = get_global_type_or_alloca

        super().__init__(
            get_global_type,
            global_type,
            alloca,
            loc=loc,
            ip=ip,
        )
63
64
@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefMultiBufferOp(MemRefMultiBufferOp):
    """Specialization for MemRefMultiBufferOp class.

    Adds a convenience constructor with two call forms:

    * ``MemRefMultiBufferOp(transformed_type, target, factor)`` -- the result
      handle type is given explicitly.
    * ``MemRefMultiBufferOp(target, factor)`` -- the result handle type
      defaults to ``transform.AnyOpType``.

    ``skip_analysis``, ``loc`` and ``ip`` are keyword-only in both forms.
    """

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        # Accepts a result Type (explicit form) or the target handle (short
        # form), so the annotation must cover both.
        transformed_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_factor: Optional[
            Union[int, IntegerAttr, Operation, OpView, Value]
        ] = None,
        factor_or_none: Optional[Union[int, IntegerAttr]] = None,
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        """Build the op, dispatching on the type of the first argument.

        Args:
            transformed_type_or_target: Either the result handle type
                (explicit form) or the target handle (short form).
            target_or_factor: The target handle (explicit form) or the
                multi-buffering factor (short form). Required in both forms.
            factor_or_none: The multi-buffering factor; required in the
                explicit form and must be omitted in the short form.
            skip_analysis: Forwarded unchanged to the generated builder.
            loc: Optional location forwarded to the generated builder.
            ip: Optional insertion point forwarded to the generated builder.

        Raises:
            TypeError: If the positional arguments do not match either form.
        """
        if isinstance(transformed_type_or_target, Type):
            # Explicit form: (transformed_type, target, factor).
            # Compare against None so that a factor of 0 is still "provided".
            if target_or_factor is None or factor_or_none is None:
                raise TypeError(
                    "MemRefMultiBufferOp: when a result type is given, both "
                    "'target' and 'factor' must also be provided"
                )
            transformed_type = transformed_type_or_target
            target = target_or_factor
            factor = factor_or_none
        else:
            # Short form: (target, factor).
            if target_or_factor is None:
                raise TypeError("MemRefMultiBufferOp: missing 'factor' argument")
            if factor_or_none is not None:
                # Do not silently drop a stray third positional argument.
                raise TypeError(
                    "MemRefMultiBufferOp: unexpected extra positional argument; "
                    "pass either (target, factor) or "
                    "(transformed_type, target, factor)"
                )
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target
            factor = target_or_factor

        super().__init__(
            transformed_type,
            target,
            factor,
            skip_analysis=skip_analysis,
            loc=loc,
            ip=ip,
        )
121