xref: /llvm-project/mlir/python/mlir/dialects/_ods_common.py (revision acde3f722ff3766f6f793884108d342b78623fe4)
1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
from typing import (
    Any as _Any,
    List as _List,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Type as _Type,
    Union as _Union,
)
13
14from .._mlir_libs import _mlir as _cext
15from ..ir import (
16    ArrayAttr,
17    Attribute,
18    BoolAttr,
19    DenseI64ArrayAttr,
20    IntegerAttr,
21    IntegerType,
22    OpView,
23    Operation,
24    ShapedType,
25    Value,
26)
27
# Names exported by `from ... import *`. Only the accessor and result/value
# helper functions are listed; the underscore-prefixed helpers and the typing
# aliases defined in this module are not part of this list.
__all__ = [
    "equally_sized_accessor",
    "get_default_loc_context",
    "get_op_result_or_value",
    "get_op_results_or_values",
    "get_op_result_or_op_results",
    "segmented_accessor",
]
36
37
def segmented_accessor(elements, raw_segments, idx):
    """
    Extracts the idx-th segment out of a segmented container.

      elements: a sliceable container (operands or results).
      raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass, holding
          the size of every segment.
      idx: position of the requested segment.
    """
    sizes = _cext.ir.DenseI32ArrayAttr(raw_segments)
    # The requested segment begins right after all the segments preceding it.
    begin = 0
    for position in range(idx):
        begin += sizes[position]
    return elements[begin : begin + sizes[idx]]
51
52
def equally_sized_accessor(
    elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
):
    """
    Computes where an equally-sized variadic group starts and how long it is.

    All variadic groups are assumed to have the same length, so that length is
    whatever remains after removing the non-variadic groups, divided evenly
    between the variadic groups.

      elements: a sequential container.
      n_simple: the number of non-variadic groups in the container.
      n_variadic: the number of variadic groups in the container.
      n_preceding_simple: the number of non-variadic groups preceding the
          current group.
      n_preceding_variadic: the number of variadic groups preceding the
          current group.

    Returns a `(start, elements_per_group)` pair.
    """
    group_size, leftover = divmod(len(elements) - n_simple, n_variadic)
    # An even split should be enforced by the C++-side trait verifier.
    assert leftover == 0

    # Every preceding simple group contributes one element, every preceding
    # variadic group contributes a full group.
    offset = n_preceding_simple + group_size * n_preceding_variadic
    return offset, group_size
76
77
def get_default_loc_context(location=None):
    """
    Returns the context of the given location, or of the current location.

    When `location` is None, the current location is taken from the stack;
    ValueError is raised if there is no location on the stack.
    """
    if location is not None:
        return location.context
    # Location.current raises ValueError if there is no current location.
    return _cext.ir.Location.current.context
88
89
def get_op_result_or_value(
    arg: _Union[
        _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
    ]
) -> _cext.ir.Value:
    """Returns the given value or the single result of the given op.

    Op constructors use this so that callers can pass other ops as arguments
    directly instead of extracting the result of every op themselves. A result
    list is reduced to its first element, and a plain value is returned as is.
    Raises ValueError if provided with an op that doesn't have a single result.
    """
    if isinstance(arg, _cext.ir.OpView):
        # An op view wraps an operation; the result is read from the latter.
        return arg.operation.result
    if isinstance(arg, _cext.ir.Operation):
        return arg.result
    if isinstance(arg, _cext.ir.OpResultList):
        return arg[0]
    # Nothing op-like matched, so the argument must already be a value.
    assert isinstance(arg, _cext.ir.Value)
    return arg
110
111
def get_op_results_or_values(
    arg: _Union[
        _cext.ir.OpView,
        _cext.ir.Operation,
        _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    ]
) -> _Union[
    _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    _cext.ir.OpResultList,
]:
    """Returns the given sequence of values or the results of the given op.

    Op constructors use this so that callers can pass an op wherever a list of
    arguments is expected instead of extracting the results of every op
    themselves. Anything that is neither an op view nor an operation is
    returned unchanged.
    """
    if isinstance(arg, _cext.ir.OpView):
        # An op view wraps an operation; the results are read from the latter.
        return arg.operation.results
    if isinstance(arg, _cext.ir.Operation):
        return arg.results
    return arg
134
135
def get_op_result_or_op_results(
    op: _Union[_cext.ir.OpView, _cext.ir.Operation],
) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
    """Returns the most convenient handle for what the given op produces.

    - Exactly one result: that result.
    - Several results: the results container of the op.
    - No result: the operation itself (unwrapped from its view if needed).
    """
    op_results = op.results
    count = len(op_results)
    if count > 1:
        return op_results
    if count == 1:
        return op_results[0]
    # Result-less op: hand back the underlying operation.
    return op.operation if isinstance(op, _cext.ir.OpView) else op
149
# Classes that can stand for "an op or a value", kept as a plain tuple, along
# with the typing aliases derived from that tuple.
ResultValueTypeTuple = (_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value)
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

# Integers given statically, SSA-value-like objects, and the mix of both.
StaticIntLike = _Union[int, IntegerAttr]
ValueLike = _Union[Operation, OpView, Value]
MixedInt = _Union[StaticIntLike, ValueLike]

# Optional lists of static integers, either as Python data or as an ArrayAttr.
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]

# Optional lists of static booleans, either as Python data or as an ArrayAttr.
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]

# A sequence of mixed integers, an ArrayAttr, or one value-like object.
MixedValues = _Union[_Sequence[MixedInt], ArrayAttr, ValueLike]

# Each entry is a mixed integer, or a sequence of mixed integers.
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
167
168
def _dispatch_dynamic_index_list(
    indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
    """Splits a list of mixed indices into dynamic, static and scalable parts.

    This is similar to the custom `DynamicIndexList` directive upstream: each
    index may be a dynamic SSA value or a static value, and is marked as
    scalable by being wrapped in a singleton sequence.

    Returns a tuple of three lists:
      - the dynamic indices (ops or values) in order of appearance;
      - one static entry per index: the integer of a static index, or
        `ShapedType.get_dynamic_size()` as placeholder for a dynamic index;
      - one bool per index telling whether that index was scalable.
    """
    if isinstance(indices, ArrayAttr):
        # ArrayAttr: unpack it into a plain list of its elements.
        indices = list(indices)

    count = len(indices)
    dynamic = []
    static = [ShapedType.get_dynamic_size()] * count
    scalable = [False] * count

    def store(position, candidate):
        """Records a non-scalable index; returns False if it is not one."""
        if isinstance(candidate, int):
            static[position] = candidate
            return True
        if isinstance(candidate, IntegerAttr):
            static[position] = candidate.value  # pytype: disable=attribute-error
            return True
        if isinstance(candidate, (Operation, Value, OpView)):
            # The static slot keeps its "dynamic size" placeholder.
            dynamic.append(candidate)
            return True
        return False

    for position, entry in enumerate(indices):
        if store(position, entry):
            continue
        # Anything not recognized above must be a scalable index, which is
        # given as a sequence of exactly one value: unwrap and record it.
        scalable[position] = True
        assert len(entry) == 1
        unwrapped_ok = store(position, entry[0])
        assert unwrapped_ok

    return dynamic, static, scalable
215
216
def _dispatch_mixed_values(
    values: MixedValues,
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
    """Dispatches `MixedValues` that all represent integers in various forms.

    The input is sorted into the following three categories:
      - `dynamic_values`: a list of `Value`s, potentially from op results;
      - `packed_values`: a value handle, potentially from an op result,
        associated to one or more payload operations of integer type;
      - `static_values`: static values coming from Python `int`s, from
        `IntegerAttr`s, or from an `ArrayAttr`.

    If the input is in the form for `packed_values`, only that result is set:
    `dynamic_values` is empty and `static_values` is None. If the input is an
    `ArrayAttr`, it is returned unchanged as `static_values`. Otherwise the
    input (None is treated as empty) may mix static and dynamic entries; the
    static ones are collected into a `DenseI64ArrayAttr`, and for each dynamic
    value the special `ShapedType.get_dynamic_size()` value is added to
    `static_values` in its place.

    Returns the tuple `(dynamic_values, packed_values, static_values)`.
    """
    dynamic_values = []
    packed_values = None
    static_values = None
    if isinstance(values, ArrayAttr):
        static_values = values
    elif isinstance(values, (Operation, Value, OpView)):
        packed_values = values
    else:
        static_values = []
        for size in values or []:
            if isinstance(size, int):
                static_values.append(size)
            elif isinstance(size, IntegerAttr):
                # An IntegerAttr is a static integer as well; without this
                # branch it would be mistaken for a dynamic SSA value.
                static_values.append(size.value)
            else:
                static_values.append(ShapedType.get_dynamic_size())
                dynamic_values.append(size)
        static_values = DenseI64ArrayAttr.get(static_values)

    return (dynamic_values, packed_values, static_values)
248
249
def _get_value_or_attribute_value(
    value_or_attr: _Union[_Any, Attribute, ArrayAttr]
) -> _Any:
    """Unwraps an attribute into the plain Python data it holds.

    - An `Attribute` exposing a `value` member yields that member.
    - An `ArrayAttr` yields a list with every element unwrapped in turn.
    - Anything else is returned unchanged.
    """
    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
        return value_or_attr.value
    if isinstance(value_or_attr, ArrayAttr):
        return _get_value_list(value_or_attr)
    return value_or_attr
258
259
def _get_value_list(
    sequence_or_array_attr: _Union[_Sequence[_Any], ArrayAttr]
) -> _Sequence[_Any]:
    """Returns a new list with every element of the input unwrapped.

    Each element is passed through `_get_value_or_attribute_value`, so
    attributes are replaced by their values while other elements are kept.
    """
    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
264
265
def _get_int_array_attr(
    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> _Optional[ArrayAttr]:
    """Creates an ArrayAttr of signless 64-bit IntegerAttrs.

    The input may be an ArrayAttr or a sequence whose elements are Python ints
    or IntegerAttrs. If the input is None, None is returned.
    """
    if values is None:
        return None

    # Turn into a Python list of Python ints.
    values = _get_value_list(values)

    # Make an ArrayAttr of IntegerAttrs out of it.
    return ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
    )
279
280
def _get_int_array_array_attr(
    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> _Optional[ArrayAttr]:
    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.

    The input has to be a collection of a collection of integers, where any
    Python _Sequence and ArrayAttr are admissible collections and Python ints and
    any IntegerAttr are admissible integers. Both levels of collections are
    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
    If the input is None, None is returned.
    """
    if values is None:
        return None

    # Make sure the outer level is a list.
    values = _get_value_list(values)

    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
    # Sequences. Make sure the nested values are all lists.
    values = [_get_value_list(nested) for nested in values]

    # Turn each nested list into an ArrayAttr.
    values = [_get_int_array_attr(nested) for nested in values]

    # Turn the outer list into an ArrayAttr.
    return ArrayAttr.get(values)
307