# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._gpu_transform_ops_gen import *
from .._gpu_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, overload


def _handle_type(handle: Union[Operation, OpView, Value]) -> Type:
    """Return the type of the handle designated by ``handle``.

    ``Operation`` and ``OpView`` objects have no ``.type`` of their own, so the
    type of their single result is used (the bindings' ``.result`` accessor
    rejects ops that do not have exactly one result). Anything else, typically
    a ``Value``, is asked for its ``.type`` directly.
    """
    if isinstance(handle, (Operation, OpView)):
        return handle.result.type
    return handle.type


@_ods_cext.register_operation(_Dialect, replace=True)
class MapForallToBlocks(MapForallToBlocks):
    """Convenience wrapper around the generated ``MapForallToBlocks`` builder.

    The result type may be omitted; it then defaults to
    ``transform.AnyOpType``.
    """

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        """Build the op from ``(result_type, target)`` or from ``target`` alone.

        Args:
          result_type_or_target: Either the result handle type (then
            ``target_or_none`` supplies the target) or the target itself.
          target_or_none: The target when a result type was given first.
          grid_dims, generate_gpu_launch, loc, ip: Forwarded unchanged to the
            generated builder.
        """
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_none
        else:
            # No explicit result type: fall back to the generic op handle type.
            result_type = transform.AnyOpType.get()
            target = result_type_or_target

        super().__init__(
            result_type,
            target,
            grid_dims=grid_dims,
            generate_gpu_launch=generate_gpu_launch,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MapNestedForallToThreads(MapNestedForallToThreads):
    """Convenience wrapper around the generated ``MapNestedForallToThreads`` builder.

    The result type may be omitted; it then defaults to the type of the
    target handle.
    """

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[int, Attribute]] = None,
        sync_after_distribute: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[int, Attribute]] = None,
        sync_after_distribute: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Value, Type],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[int, Attribute]] = None,
        sync_after_distribute: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        """Build the op from ``(result_type, target)`` or from ``target`` alone.

        Args:
          result_type_or_target: Either the result handle type (then
            ``target_or_none`` supplies the target) or the target itself.
          target_or_none: The target when a result type was given first.
          block_dims, warp_size, sync_after_distribute, loc, ip: Forwarded
            unchanged to the generated builder.
        """
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_none
        else:
            target = result_type_or_target
            # Reuse the target's handle type; Operation/OpView targets have no
            # `.type`, so the helper looks through to their single result.
            result_type = _handle_type(target)

        super().__init__(
            result_type,
            target,
            block_dims=block_dims,
            warp_size=warp_size,
            sync_after_distribute=sync_after_distribute,
            loc=loc,
            ip=ip,
        )