xref: /llvm-project/mlir/python/mlir/dialects/transform/loop.py (revision 4c654b7b91aff61728619fc3cc955fa5169d17c6)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._loop_transform_ops_gen import *
6from .._loop_transform_ops_gen import _Dialect
7
8try:
9    from ...ir import *
10    from .._ods_common import (
11        get_op_result_or_value as _get_op_result_or_value,
12        _cext as _ods_cext,
13    )
14except ImportError as e:
15    raise RuntimeError("Error loading imports from extension module") from e
16
17from typing import Optional, Union
18
19
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopOutlineOp(LoopOutlineOp):
    """Convenience constructor layered on the generated `LoopOutlineOp`.

    Registered on the dialect with `replace=True` (presumably so this class
    takes the place of the generated one -- confirm against `_ods_cext`).
    """

    def __init__(
        self,
        function_type: Type,
        call_type: Type,
        target: Union[Operation, Value],
        *,
        func_name: Union[str, StringAttr],
        ip=None,
        loc=None,
    ):
        """Build the op, accepting a plain string for `func_name`.

        Args:
            function_type: First positional argument forwarded to the
                generated builder.
            call_type: Second positional argument forwarded to the
                generated builder.
            target: Operation or value; normalized through
                `get_op_result_or_value` before being forwarded.
            func_name: Name as a Python string or an existing `StringAttr`.
            ip: Forwarded unchanged to the generated builder.
            loc: Forwarded unchanged to the generated builder.
        """
        # Resolve the target first so argument handling happens in the same
        # order as the keyword list of the underlying call.
        target_handle = _get_op_result_or_value(target)

        # An existing StringAttr is used as-is; anything else is wrapped.
        if not isinstance(func_name, StringAttr):
            func_name = StringAttr.get(func_name)

        super().__init__(
            function_type,
            call_type,
            target_handle,
            func_name=func_name,
            ip=ip,
            loc=loc,
        )
46
47
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPeelOp(LoopPeelOp):
    """Convenience constructor layered on the generated `LoopPeelOp`.

    Registered on the dialect with `replace=True` (presumably so this class
    takes the place of the generated one -- confirm against `_ods_cext`).
    """

    def __init__(
        self,
        main_loop_type: Type,
        remainder_loop_type: Type,
        target: Union[Operation, Value],
        *,
        peel_front: Union[bool, BoolAttr] = False,
        fail_if_already_divisible: Union[bool, BoolAttr] = False,
        ip=None,
        loc=None,
    ):
        """Build the op, accepting plain booleans for the flag attributes.

        Args:
            main_loop_type: First positional argument forwarded to the
                generated builder.
            remainder_loop_type: Second positional argument forwarded to the
                generated builder.
            target: Operation or value; normalized through
                `get_op_result_or_value` before being forwarded.
            peel_front: Python bool or existing `BoolAttr`; defaults to False.
            fail_if_already_divisible: Python bool or existing `BoolAttr`;
                defaults to False.
            ip: Forwarded unchanged to the generated builder.
            loc: Forwarded unchanged to the generated builder.
        """
        target_handle = _get_op_result_or_value(target)

        # Each flag is wrapped in a BoolAttr unless the caller already
        # supplied one.
        if not isinstance(peel_front, BoolAttr):
            peel_front = BoolAttr.get(peel_front)
        if not isinstance(fail_if_already_divisible, BoolAttr):
            fail_if_already_divisible = BoolAttr.get(fail_if_already_divisible)

        super().__init__(
            main_loop_type,
            remainder_loop_type,
            target_handle,
            peel_front=peel_front,
            fail_if_already_divisible=fail_if_already_divisible,
            ip=ip,
            loc=loc,
        )
80
81
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPipelineOp(LoopPipelineOp):
    """Convenience constructor layered on the generated `LoopPipelineOp`.

    Registered on the dialect with `replace=True` (presumably so this class
    takes the place of the generated one -- confirm against `_ods_cext`).
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        iteration_interval: Optional[Union[int, IntegerAttr]] = None,
        read_latency: Optional[Union[int, IntegerAttr]] = None,
        ip=None,
        loc=None,
    ):
        """Build the op, filling in defaults for omitted attributes.

        Args:
            result_type: First positional argument forwarded to the
                generated builder.
            target: Operation or value; normalized through
                `get_op_result_or_value` before being forwarded.
            iteration_interval: Forwarded as given; `None` is replaced by 1.
            read_latency: Forwarded as given; `None` is replaced by 10.
            ip: Forwarded unchanged to the generated builder.
            loc: Forwarded unchanged to the generated builder.
        """
        # Only `None` triggers the fallback; any other value (including 0)
        # is passed through untouched.
        iteration_interval = 1 if iteration_interval is None else iteration_interval
        read_latency = 10 if read_latency is None else read_latency

        target_handle = _get_op_result_or_value(target)
        super().__init__(
            result_type,
            target_handle,
            iteration_interval=iteration_interval,
            read_latency=read_latency,
            ip=ip,
            loc=loc,
        )
108
109
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopUnrollOp(LoopUnrollOp):
    """Convenience constructor layered on the generated `LoopUnrollOp`.

    Registered on the dialect with `replace=True` (presumably so this class
    takes the place of the generated one -- confirm against `_ods_cext`).
    """

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        factor: Union[int, IntegerAttr],
        ip=None,
        loc=None,
    ):
        """Build the op from an operation or value target.

        Args:
            target: Operation or value; normalized through
                `get_op_result_or_value` before being forwarded.
            factor: Forwarded unchanged to the generated builder.
            ip: Forwarded unchanged to the generated builder.
            loc: Forwarded unchanged to the generated builder.
        """
        target_handle = _get_op_result_or_value(target)
        super().__init__(target_handle, factor=factor, ip=ip, loc=loc)
128