#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._gpu_transform_ops_gen import *
from .._gpu_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, overload


@_ods_cext.register_operation(_Dialect, replace=True)
class MapForallToBlocks(MapForallToBlocks):
20    """Specialization for MapForallToBlocks class."""

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        # Dispatch on the first positional argument: a Type selects the
        # explicit-result-type overload, anything else is the target handle.
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_none
        else:
            # Default to the most generic handle type when no explicit
            # result type is given.
            result_type = transform.AnyOpType.get()
            target = result_type_or_target

        super().__init__(
            result_type,
            target,
            grid_dims=grid_dims,
            generate_gpu_launch=generate_gpu_launch,
            loc=loc,
            ip=ip,
        )
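
# A minimal usage sketch, not part of the original module; the helper name
# `_example_map_forall_to_blocks` is illustrative. Assuming the upstream
# transform-dialect bindings (SequenceOp, YieldOp, AnyOpType,
# FailurePropagationMode), it builds a transform script that maps a payload
# `scf.forall` to GPU blocks.
def _example_map_forall_to_blocks() -> Module:
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                # No explicit result type, so the specialization above
                # defaults it to `transform.AnyOpType`.
                MapForallToBlocks(
                    sequence.bodyTarget,
                    grid_dims=[8, 4, 1],
                    generate_gpu_launch=True,
                )
                transform.YieldOp()
        return module
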
@_ods_cext.register_operation(_Dialect, replace=True)
class MapNestedForallToThreads(MapNestedForallToThreads):
    """Specialization for the MapNestedForallToThreads op.

    The result type may be omitted, in which case it is inferred from the
    type of the target handle.
    """

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
        sync_after_distribute: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
        sync_after_distribute: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Value, Type],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
        sync_after_distribute: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        # Dispatch on the first positional argument, as in MapForallToBlocks.
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_none
        else:
            # Infer the result type from the target handle when no explicit
            # result type is given.
            result_type = result_type_or_target.type
            target = result_type_or_target
        super().__init__(
            result_type,
            target,
            block_dims=block_dims,
            warp_size=warp_size,
            sync_after_distribute=sync_after_distribute,
            loc=loc,
            ip=ip,
        )
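
# A companion sketch, likewise illustrative rather than part of the original
# module: map to blocks first, then distribute the nested `scf.forall` ops
# onto GPU threads. `.result` is used so the target passed on is a Value,
# whose type the specialization above can reuse as the result type.
def _example_map_nested_forall_to_threads() -> Module:
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(sequence.body):
                # Chain the two specializations: blocks first, then threads.
                blocks = MapForallToBlocks(sequence.bodyTarget).result
                MapNestedForallToThreads(blocks, block_dims=[32, 4, 1])
                transform.YieldOp()
        return module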