# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import (
    Any as _Any,
    List as _List,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Type as _Type,
    Union as _Union,
)

from .._mlir_libs import _mlir as _cext
from ..ir import (
    ArrayAttr,
    Attribute,
    BoolAttr,
    DenseI64ArrayAttr,
    IntegerAttr,
    IntegerType,
    OpView,
    Operation,
    ShapedType,
    Value,
)

__all__ = [
    "equally_sized_accessor",
    "get_default_loc_context",
    "get_op_result_or_value",
    "get_op_results_or_values",
    "get_op_result_or_op_results",
    "segmented_accessor",
]


def segmented_accessor(elements, raw_segments, idx):
    """
    Returns a slice of elements corresponding to the idx-th segment.

    elements: a sliceable container (operands or results).
    raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
        sizes of the segments.
    idx: index of the segment.
    """
    segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
    # The segment starts after all elements of the preceding segments.
    start = sum(segments[i] for i in range(idx))
    end = start + segments[idx]
    return elements[start:end]


def equally_sized_accessor(
    elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
):
    """
    Returns a starting position and a number of elements per variadic group
    assuming equally-sized groups and the given numbers of preceding groups.

    elements: a sequential container.
    n_simple: the number of non-variadic groups in the container.
    n_variadic: the number of variadic groups in the container.
    n_preceding_simple: the number of non-variadic groups preceding the current
        group.
    n_preceding_variadic: the number of variadic groups preceding the current
        group.
    """

    total_variadic_length = len(elements) - n_simple
    # This should be enforced by the C++-side trait verifier.
    assert total_variadic_length % n_variadic == 0

    elements_per_group = total_variadic_length // n_variadic
    start = n_preceding_simple + n_preceding_variadic * elements_per_group
    return start, elements_per_group


def get_default_loc_context(location=None):
    """
    Returns a context in which the defaulted location is created. If the location
    is None, takes the current location from the stack, raises ValueError if there
    is no location on the stack.
    """
    if location is None:
        # Location.current raises ValueError if there is no current location.
        return _cext.ir.Location.current.context
    return location.context


def get_op_result_or_value(
    arg: _Union[
        _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
    ]
) -> _cext.ir.Value:
    """Returns the given value or the single result of the given op.

    This is useful to implement op constructors so that they can take other ops as
    arguments instead of requiring the caller to extract results for every op.
    Raises ValueError if provided with an op that doesn't have a single result.
    For an `OpResultList`, its first element is returned.
    """
    if isinstance(arg, _cext.ir.OpView):
        return arg.operation.result
    elif isinstance(arg, _cext.ir.Operation):
        return arg.result
    elif isinstance(arg, _cext.ir.OpResultList):
        return arg[0]
    else:
        assert isinstance(arg, _cext.ir.Value)
        return arg


def get_op_results_or_values(
    arg: _Union[
        _cext.ir.OpView,
        _cext.ir.Operation,
        _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    ]
) -> _Union[
    _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    _cext.ir.OpResultList,
]:
    """Returns the given sequence of values or the results of the given op.

    This is useful to implement op constructors so that they can take other ops as
    lists of arguments instead of requiring the caller to extract results for
    every op.
    """
    if isinstance(arg, _cext.ir.OpView):
        return arg.operation.results
    elif isinstance(arg, _cext.ir.Operation):
        return arg.results
    else:
        return arg


def get_op_result_or_op_results(
    op: _Union[_cext.ir.OpView, _cext.ir.Operation],
) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
    """Returns the most convenient "result" form of an op.

    Returns the single result if the op has exactly one, the list of results if
    it has several, and the underlying `Operation` (unwrapping an `OpView`) if
    it has none.
    """
    results = op.results
    num_results = len(results)
    if num_results == 1:
        return results[0]
    elif num_results > 1:
        return results
    elif isinstance(op, _cext.ir.OpView):
        return op.operation
    else:
        return op


ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

StaticIntLike = _Union[int, IntegerAttr]
ValueLike = _Union[Operation, OpView, Value]
MixedInt = _Union[StaticIntLike, ValueLike]

IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]

BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]

MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]

DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]


def _dispatch_dynamic_index_list(
    indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _List[int], _List[bool]]:
    """Dispatches a list of indices to the appropriate form.

    This is similar to the custom `DynamicIndexList` directive upstream:
    provided indices may be in the form of dynamic SSA values or static values,
    and they may be scalable (i.e., as a singleton list) or not. This function
    dispatches each index into its respective form. It also extracts the SSA
    values and static indices from various similar structures, respectively.

    Returns a tuple `(dynamic_indices, static_indices, scalable_indices)`:
      - `dynamic_indices`: the `Operation`/`Value`/`OpView` indices, in order;
      - `static_indices`: one int per index, either its static value or
        `ShapedType.get_dynamic_size()` where the index is dynamic;
      - `scalable_indices`: one flag per index, True where the index was given
        wrapped in a singleton sequence.

    Raises ValueError if an index given in sequence form does not hold exactly
    one supported index.
    """
    dynamic_indices = []
    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
    scalable_indices = [False] * len(indices)

    # ArrayAttr: Extract index values.
    if isinstance(indices, ArrayAttr):
        indices = list(indices)

    def process_nonscalable_index(i, index):
        """Processes any form of non-scalable index.

        Returns False, leaving the index unprocessed, if it is not an int, an
        IntegerAttr, or an Operation/Value/OpView (e.g. when it is a scalable
        index wrapped in a sequence); True otherwise.
        """
        if isinstance(index, int):
            static_indices[i] = index
        elif isinstance(index, IntegerAttr):
            static_indices[i] = index.value  # pytype: disable=attribute-error
        elif isinstance(index, (Operation, Value, OpView)):
            dynamic_indices.append(index)
        else:
            return False
        return True

    # Process each index at a time.
    for i, index in enumerate(indices):
        if process_nonscalable_index(i, index):
            continue
        # If it wasn't processed, it must be a scalable index, which is
        # provided as a _Sequence of one value, so extract and process that.
        # Explicit checks (rather than `assert`) keep this validation active
        # under `python -O`.
        if len(index) != 1 or not process_nonscalable_index(i, index[0]):
            raise ValueError(
                f"index {i} must be an int, IntegerAttr, SSA value, or a "
                f"singleton sequence holding one of those; got {index!r}"
            )
        scalable_indices[i] = True

    return dynamic_indices, static_indices, scalable_indices


def _dispatch_mixed_values(
    values: MixedValues,
) -> _Tuple[
    _List[Value],
    _Optional[ValueLike],
    _Optional[_Union[ArrayAttr, DenseI64ArrayAttr]],
]:
    """Dispatches `MixedValues` that all represent integers in various forms.

    Returns a tuple `(dynamic_values, packed_values, static_values)`:
      - `dynamic_values`: a list of `Value`s, potentially from op results;
      - `packed_values`: a value handle, potentially from an op result,
        associated to one or more payload operations of integer type;
      - `static_values`: the static values, from Python `int`s, from
        `IntegerAttr`s, or from an `ArrayAttr`.

    If the input is in the form for `packed_values` (a single `Operation`,
    `OpView` or `Value`), only that result is set: `dynamic_values` is empty
    and `static_values` is None. If the input is an `ArrayAttr`, it is returned
    unchanged as `static_values`. Otherwise, the input is a sequence (None is
    treated as empty) mixing the other two forms. It is turned into a
    `DenseI64ArrayAttr`, and for each dynamic value the special value
    `ShapedType.get_dynamic_size()` is recorded at its position.
    """
    dynamic_values = []
    packed_values = None
    static_values = None
    if isinstance(values, ArrayAttr):
        static_values = values
    elif isinstance(values, (Operation, Value, OpView)):
        packed_values = values
    else:
        static_values = []
        for size in values or []:
            if isinstance(size, int):
                static_values.append(size)
            elif isinstance(size, IntegerAttr):
                # An IntegerAttr is a static integer (see `StaticIntLike`),
                # not a dynamic value, so unwrap it instead of passing it on as
                # an SSA operand.
                static_values.append(size.value)  # pytype: disable=attribute-error
            else:
                static_values.append(ShapedType.get_dynamic_size())
                dynamic_values.append(size)
        static_values = DenseI64ArrayAttr.get(static_values)

    return (dynamic_values, packed_values, static_values)


def _get_value_or_attribute_value(
    value_or_attr: _Union[_Any, Attribute, ArrayAttr]
) -> _Any:
    """Converts an attribute into the corresponding plain Python value.

    Attributes that expose a `value` property (such as `IntegerAttr`) are
    replaced by that value. An `ArrayAttr` is converted element-wise into a
    Python list. Anything else is returned unchanged.
    """
    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
        return value_or_attr.value
    if isinstance(value_or_attr, ArrayAttr):
        return _get_value_list(value_or_attr)
    return value_or_attr


def _get_value_list(
    sequence_or_array_attr: _Union[_Sequence[_Any], ArrayAttr]
) -> _Sequence[_Any]:
    """Returns a new list with `_get_value_or_attribute_value` applied to each element."""
    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]


def _get_int_array_attr(
    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> _Optional[ArrayAttr]:
    """Creates an ArrayAttr of signless i64 IntegerAttrs.

    Accepts an ArrayAttr or a Python sequence of ints and/or IntegerAttrs.
    Returns None if the input is None.
    """
    if values is None:
        return None

    # Turn into a Python list of Python ints.
    values = _get_value_list(values)

    # Make an ArrayAttr of IntegerAttrs out of it.
    return ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
    )


def _get_int_array_array_attr(
    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> _Optional[ArrayAttr]:
    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.

    The input has to be a collection of a collection of integers, where any
    Python _Sequence and ArrayAttr are admissible collections and Python ints and
    any IntegerAttr are admissible integers. Both levels of collections are
    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
    If the input is None, None is returned.
    """
    if values is None:
        return None

    # Make sure the outer level is a list.
    values = _get_value_list(values)

    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
    # Sequences. Make sure the nested values are all lists.
    values = [_get_value_list(nested) for nested in values]

    # Turn each nested list into an ArrayAttr.
    values = [_get_int_array_attr(nested) for nested in values]

    # Turn the outer list into an ArrayAttr.
    return ArrayAttr.get(values)