# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import operator
from itertools import accumulate
from typing import Optional

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation


def _is_constant_int_like(i):
    """Returns True if `i` is the result of an integer-valued arith.constant."""
    return (
        isinstance(i, Value)
        and isinstance(i.owner, Operation)
        and isinstance(i.owner.opview, ConstantOp)
        and _is_integer_like_type(i.type)
    )


def _is_static_int_like(i):
    """Returns True if `i` is a static Python int or a constant integer Value."""
    return (
        isinstance(i, int) and not ShapedType.is_dynamic_size(i)
    ) or _is_constant_int_like(i)


def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infers the result type of a subview with static sizes and strides.

    Returns the canonicalized offsets, sizes, and strides (with arith.constant
    results folded to Python ints) along with the inferred MemRefType.
    """
    source_strides, source_offset = source_memref_type.get_strides_and_offset()
    # Canonicalize tuple|list inputs to fresh lists so they can be mutated below.
    offsets, static_sizes, static_strides, source_strides = map(
        list, (offsets, static_sizes, static_strides, source_strides)
    )

    if not all(
        all(_is_static_int_like(i) for i in s)
        for s in [
            static_sizes,
            static_strides,
            source_strides,
        ]
    ):
        raise ValueError(
            "Only inference from Python or MLIR integer constants is supported."
        )

    # Fold arith.constant results to plain Python ints.
    for s in [offsets, static_sizes, static_strides]:
        for idx, i in enumerate(s):
            if _is_constant_int_like(i):
                s[idx] = i.owner.opview.literal_value

    if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
        target_offset = ShapedType.get_dynamic_size()
    else:
        target_offset = source_offset
        for offset, target_stride in zip(offsets, source_strides):
            target_offset += offset * target_stride

    target_strides = []
    for source_stride, static_stride in zip(source_strides, static_strides):
        target_strides.append(source_stride * static_stride)

    # If the result has the default (identity) row-major striding and a zero
    # offset, omit the layout so downstream ops (e.g., expand_shape) see a
    # plain memref type.
    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
    if target_strides == default_strides and target_offset == 0:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)
    return (
        offsets,
        static_sizes,
        static_strides,
        MemRefType.get(
            static_sizes,
            source_memref_type.element_type,
            layout,
            source_memref_type.memory_space,
        ),
    )


# Keep a handle on the ODS-generated builder before shadowing it with the
# wrapper below.
_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Builds a memref.subview op.

    If `result_type` is omitted, it is inferred, which requires the sizes,
    strides, and source strides to be static (Python ints or arith.constant
    results); otherwise an explicit `result_type` must be supplied.
    """
    if offsets is None:
        offsets = []
    if sizes is None:
        sizes = []
    if strides is None:
        strides = []
    source_strides, source_offset = source.type.get_strides_and_offset()
    if result_type is None and all(
        all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
    ):
        # Any arith.constant results are canonicalized to Python ints, which
        # then fully specify the subview.
        (
            offsets,
            sizes,
            strides,
            result_type,
        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
    elif result_type is None:
        raise ValueError(
            "Mixed static/dynamic offsets/sizes/strides require an explicit result type."
        )

    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )
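

# A minimal usage sketch of the `subview` wrapper above (the shapes, the
# `tile` function name, and the computed layout are illustrative assumptions,
# not part of this module). With fully static offsets/sizes/strides the
# result type is inferred; mixed static/dynamic values require passing
# `result_type` explicitly.
#
#   from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, MemRefType
#   from mlir.dialects import func, memref
#
#   with Context(), Location.unknown():
#       module = Module.create()
#       with InsertionPoint(module.body):
#           f32 = F32Type.get()
#
#           @func.FuncOp.from_py_func(MemRefType.get([8, 8], f32))
#           def tile(src):
#               # memref<8x8xf32> has identity strides [8, 1], so this should
#               # infer memref<4x4xf32, strided<[8, 1], offset: 18>>.
#               return memref.subview(src, offsets=[2, 2], sizes=[4, 4], strides=[1, 1])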