xref: /llvm-project/mlir/python/mlir/dialects/transform/tensor.py (revision a2288a8944c310fcad1196302f16513797e1fcbc)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from .._tensor_transform_ops_gen import *
6from .._tensor_transform_ops_gen import _Dialect
7
8try:
9    from ...ir import *
10    from ...dialects import transform
11    from .._ods_common import _cext as _ods_cext
12except ImportError as e:
13    raise RuntimeError("Error loading imports from extension module") from e
14
15from typing import Optional, overload, Union
16
17
@_ods_cext.register_operation(_Dialect, replace=True)
class MakeLoopIndependentOp(MakeLoopIndependentOp):
    """Specialization for MakeLoopIndependentOp class.

    Extends the generated op class with a constructor that accepts two
    positional call shapes:

    * ``(transformed_type, target, num_loops)`` -- the result type is given
      explicitly.
    * ``(target, num_loops)`` -- the result type defaults to
      ``transform.AnyOpType.get()``.

    The shape is selected by checking whether the first positional argument
    is an instance of ``Type``.
    """

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        num_loops: Union[int, IntegerAttr],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        num_loops: Union[int, IntegerAttr],
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        transformed_type_or_target: Union[Type, Operation, OpView, Value],
        target_or_num_loops: Optional[
            Union[int, IntegerAttr, Operation, OpView, Value]
        ] = None,
        num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Build the op from either of the two overloaded argument lists.

        Args:
            transformed_type_or_target: The result type (three-argument form)
                or the target (two-argument form).
            target_or_num_loops: The target (three-argument form) or the
                number of loops (two-argument form).
            num_loops_or_none: The number of loops in the three-argument
                form; must be left unset in the two-argument form.
            loc: Forwarded unchanged to the base class constructor.
            ip: Forwarded unchanged to the base class constructor.

        Raises:
            TypeError: If a third positional argument is supplied in the
                two-argument form, or if ``num_loops`` is missing.
        """
        if isinstance(transformed_type_or_target, Type):
            # Three-argument form: explicit result type comes first.
            transformed_type = transformed_type_or_target
            target = target_or_num_loops
            num_loops = num_loops_or_none
        else:
            # Two-argument form: everything shifts one slot to the left, so
            # the third slot has no meaning here. Reject it instead of
            # silently discarding the caller's value.
            if num_loops_or_none is not None:
                raise TypeError(
                    "MakeLoopIndependentOp: unexpected third positional "
                    "argument when no result type is given; expected "
                    "(target, num_loops)"
                )
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target
            num_loops = target_or_num_loops

        # Both overloads declare num_loops as required; the None default only
        # exists so one implementation signature can serve both shapes.
        if num_loops is None:
            raise TypeError(
                "MakeLoopIndependentOp: missing required argument 'num_loops'"
            )

        super().__init__(
            transformed_type,
            target,
            num_loops,
            loc=loc,
            ip=ip,
        )
70