xref: /llvm-project/mlir/python/mlir/dialects/memref.py (revision 7c3dfb29dc4b5345da6a7fb25f92bf8d2919bce9)
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4import operator
5from itertools import accumulate
6from typing import Optional
7
8from ._memref_ops_gen import *
9from ._ods_common import _dispatch_mixed_values, MixedValues
10from .arith import ConstantOp, _is_integer_like_type
11from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
12
13
def _is_constant_int_like(i):
    """Tell whether ``i`` is an SSA value defined by an integer-like ``arith.constant``.

    Each requirement is checked in turn: ``i`` must be a ``Value``, its owner
    must be an ``Operation`` whose opview is an arith ``ConstantOp``, and the
    value's type must pass ``_is_integer_like_type``.
    """
    if not isinstance(i, Value):
        return False
    producer = i.owner
    # Block arguments are owned by a block, not an operation.
    if not isinstance(producer, Operation):
        return False
    if not isinstance(producer.opview, ConstantOp):
        return False
    return _is_integer_like_type(i.type)
21
22
def _is_static_int_like(i):
    """Tell whether ``i`` denotes a statically known integer.

    A Python ``int`` qualifies unless it is the dynamic-size sentinel
    recognised by ``ShapedType.is_dynamic_size``. Anything else qualifies
    only if it is an integer-like ``arith.constant`` result
    (see ``_is_constant_int_like``).
    """
    if isinstance(i, int):
        if not ShapedType.is_dynamic_size(i):
            return True
    return _is_constant_int_like(i)
27
28
def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infer the result ``MemRefType`` of a ``memref.subview``.

    Args:
      source_memref_type: type of the subview source. Its layout is queried
        with ``get_strides_and_offset``, and every source stride must be a
        static integer.
      offsets: per-dimension offsets. These may be Python ints,
        integer-like ``arith.constant`` results, or other (dynamic) values.
      static_sizes: per-dimension sizes. Each must be a Python int or an
        integer-like ``arith.constant`` result.
      static_strides: per-dimension strides, with the same constraints as
        ``static_sizes``.

    Returns:
      A tuple ``(offsets, static_sizes, static_strides, result_type)``.
      The three sequences are returned as new lists in which every
      integer-like ``arith.constant`` operand has been replaced by its
      ``literal_value``. ``result_type`` is the inferred ``MemRefType``.

    Raises:
      ValueError: if any size, stride or source stride is not a static integer.
    """
    source_strides, source_offset = source_memref_type.get_strides_and_offset()
    # "canonicalize" from tuple|list -> list. ``list(...)`` also copies, so the
    # in-place constant folding below never mutates the caller's sequences.
    offsets, static_sizes, static_strides, source_strides = map(
        list, (offsets, static_sizes, static_strides, source_strides)
    )

    if not all(
        all(_is_static_int_like(i) for i in s)
        for s in (static_sizes, static_strides, source_strides)
    ):
        raise ValueError(
            "Only inferring from python or mlir integer constant is supported."
        )

    # Replace arith.constant operands with their Python int literal values.
    for s in (offsets, static_sizes, static_strides):
        for idx, i in enumerate(s):
            if _is_constant_int_like(i):
                s[idx] = i.owner.opview.literal_value

    # A single non-static offset (or a non-static source offset) makes the
    # result offset dynamic; otherwise it is the linearised start position.
    if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
        target_offset = ShapedType.get_dynamic_size()
    else:
        target_offset = source_offset
        for offset, source_stride in zip(offsets, source_strides):
            target_offset += offset * source_stride

    target_strides = [
        source_stride * static_stride
        for source_stride, static_stride in zip(source_strides, static_strides)
    ]

    # Row-major strides of the result shape: suffix products of the sizes,
    # with 1 for the innermost dimension. A rank-0 result has no strides, so
    # its default stride list is empty. Unconditionally appending ``[1]``
    # would make the identity check below fail for rank 0.
    if static_sizes:
        default_strides = list(
            accumulate(static_sizes[1:][::-1], operator.mul)
        )[::-1] + [1]
    else:
        default_strides = []
    # If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
    if target_strides == default_strides and target_offset == 0:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)
    return (
        offsets,
        static_sizes,
        static_strides,
        MemRefType.get(
            static_sizes,
            source_memref_type.element_type,
            layout,
            source_memref_type.memory_space,
        ),
    )
83
84
# Keep a handle to the ``subview`` builder brought in by
# ``from ._memref_ops_gen import *`` above. The hand-written ``subview``
# wrapper defined below rebinds the name and delegates to this alias.
_generated_subview = subview
86
87
def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Build a ``memref.subview`` of ``source``.

    ``offsets``, ``sizes`` and ``strides`` may each mix Python ints and SSA
    values. ``None`` is treated as an empty list.

    When ``result_type`` is omitted it is inferred. Inference requires every
    size, every stride, and every stride of the source layout to be a Python
    int or an integer-like ``arith.constant`` result. Offsets may still be
    dynamic. When ``result_type`` is given, the source layout is not
    inspected at all.

    Args:
      source: the memref value to take a view of.
      offsets: per-dimension offsets.
      sizes: per-dimension sizes.
      strides: per-dimension strides.
      result_type: explicit result type; inferred when ``None``.
      loc: optional location forwarded to the generated builder.
      ip: optional insertion point forwarded to the generated builder.

    Returns:
      The result of the generated ``memref.subview`` builder.

    Raises:
      ValueError: if ``result_type`` is ``None`` and cannot be inferred from
        the given sizes/strides and the source strides.
    """
    if offsets is None:
        offsets = []
    if sizes is None:
        sizes = []
    if strides is None:
        strides = []

    if result_type is None:
        # The source layout is only needed for inference. Querying it
        # unconditionally would needlessly fail for non-strided sources when
        # the caller already supplied the result type.
        source_strides, _source_offset = source.type.get_strides_and_offset()
        if not all(
            all(_is_static_int_like(i) for i in s)
            for s in (sizes, strides, source_strides)
        ):
            raise ValueError(
                "mixed static/dynamic offset/sizes/strides requires explicit result type."
            )
        # If any are arith.constant results then this will canonicalize to python int
        # (which can then be used to fully specify the subview).
        (
            offsets,
            sizes,
            strides,
            result_type,
        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)

    # Split each mixed list into the (dynamic operands, static values) form
    # the generated builder takes; the packed form is not needed here.
    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )
137