# xref: /llvm-project/mlir/python/mlir/dialects/transform/bufferization.py (revision a2288a8944c310fcad1196302f16513797e1fcbc)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._bufferization_transform_ops_gen import *
6from .._bufferization_transform_ops_gen import _Dialect
7
8try:
9    from ...ir import *
10    from ...dialects import transform
11    from .._ods_common import _cext as _ods_cext
12except ImportError as e:
13    raise RuntimeError("Error loading imports from extension module") from e
14
15from enum import Enum
16from typing import Optional, overload, Union
17
18
@_ods_cext.register_operation(_Dialect, replace=True)
class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
    """Specialization for EmptyTensorToAllocTensorOp class.

    Extends the generated builder so that the result handle type may be
    omitted. When omitted, it defaults to
    ``transform.OperationType.get("bufferization.alloc_tensor")``.
    """

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
        ...

    def __init__(
        self,
        transformed_type_or_target: Union[Type, Operation, OpView, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op with or without an explicit result handle type.

        Args:
          transformed_type_or_target: Either the result handle type (in which
            case ``target_or_none`` must be supplied) or the target handle
            itself (in which case the result type defaults to
            ``transform.OperationType.get("bufferization.alloc_tensor")``).
          target_or_none: The target handle when the first argument is a
            ``Type``; must be omitted otherwise.
          loc: Location forwarded unchanged to the generated builder.
          ip: Insertion point forwarded unchanged to the generated builder.

        Raises:
          TypeError: If a result type is given without a target, or if a
            second positional argument is given while the first argument is
            not a ``Type`` (it would otherwise be silently ignored).
        """
        if isinstance(transformed_type_or_target, Type):
            # Explicit form: (transformed_type, target).
            if target_or_none is None:
                raise TypeError(
                    "EmptyTensorToAllocTensorOp: a target must be provided "
                    "when a transformed type is given"
                )
            transformed_type = transformed_type_or_target
            target = target_or_none
        else:
            # Short form: (target,) with the default result type.
            if target_or_none is not None:
                raise TypeError(
                    "EmptyTensorToAllocTensorOp: expected a Type as the first "
                    "argument when two positional arguments are given"
                )
            transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
            target = transformed_type_or_target

        super().__init__(
            transformed_type,
            target,
            loc=loc,
            ip=ip,
        )
59
60
@_ods_cext.register_operation(_Dialect, replace=True)
class OneShotBufferizeOp(OneShotBufferizeOp):
    """Specialization for OneShotBufferizeOp class.

    Extends the generated builder so that the result handle type may be
    omitted. When omitted, it defaults to ``transform.AnyOpType.get()``.
    All bufferization options are keyword-only and forwarded unchanged to the
    generated builder.
    """

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        allow_return_allocs_from_loops: Optional[bool] = None,
        allow_unknown_ops: Optional[bool] = None,
        bufferize_function_boundaries: Optional[bool] = None,
        function_boundary_type_conversion: Optional[Enum] = None,
        memcpy_op: Optional[str] = None,
        print_conflicts: Optional[bool] = None,
        test_analysis_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        allow_return_allocs_from_loops: Optional[bool] = None,
        allow_unknown_ops: Optional[bool] = None,
        bufferize_function_boundaries: Optional[bool] = None,
        function_boundary_type_conversion: Optional[Enum] = None,
        memcpy_op: Optional[str] = None,
        print_conflicts: Optional[bool] = None,
        test_analysis_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        transformed_type_or_target: Union[Type, Operation, OpView, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        allow_return_allocs_from_loops: Optional[bool] = None,
        allow_unknown_ops: Optional[bool] = None,
        bufferize_function_boundaries: Optional[bool] = None,
        function_boundary_type_conversion: Optional[Enum] = None,
        memcpy_op: Optional[str] = None,
        print_conflicts: Optional[bool] = None,
        test_analysis_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        """Build the op with or without an explicit result handle type.

        Args:
          transformed_type_or_target: Either the result handle type (in which
            case ``target_or_none`` must be supplied) or the target handle
            itself (in which case the result type defaults to
            ``transform.AnyOpType.get()``).
          target_or_none: The target handle when the first argument is a
            ``Type``; must be omitted otherwise.
          allow_return_allocs_from_loops, allow_unknown_ops,
          bufferize_function_boundaries, function_boundary_type_conversion,
          memcpy_op, print_conflicts, test_analysis_only: Bufferization
            options, forwarded unchanged to the generated builder (``None``
            leaves the generated builder's handling in effect).
          loc: Location forwarded unchanged to the generated builder.
          ip: Insertion point forwarded unchanged to the generated builder.

        Raises:
          TypeError: If a result type is given without a target, or if a
            second positional argument is given while the first argument is
            not a ``Type`` (it would otherwise be silently ignored).
        """
        if isinstance(transformed_type_or_target, Type):
            # Explicit form: (transformed_type, target).
            if target_or_none is None:
                raise TypeError(
                    "OneShotBufferizeOp: a target must be provided when a "
                    "transformed type is given"
                )
            transformed_type = transformed_type_or_target
            target = target_or_none
        else:
            # Short form: (target,) with the default result type.
            if target_or_none is not None:
                raise TypeError(
                    "OneShotBufferizeOp: expected a Type as the first argument "
                    "when two positional arguments are given"
                )
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target

        super().__init__(
            transformed_type,
            target,
            allow_return_allocs_from_loops=allow_return_allocs_from_loops,
            allow_unknown_ops=allow_unknown_ops,
            bufferize_function_boundaries=bufferize_function_boundaries,
            function_boundary_type_conversion=function_boundary_type_conversion,
            memcpy_op=memcpy_op,
            print_conflicts=print_conflicts,
            test_analysis_only=test_analysis_only,
            loc=loc,
            ip=ip,
        )
135