xref: /llvm-project/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (revision 889b67c9d30e3024a1317431d66c22599f6c2011)
1// RUN: mlir-opt --expand-strided-metadata -split-input-file %s -o - | FileCheck %s
2
3// CHECK-LABEL: func @extract_strided_metadata_constants
4//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
5func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
6    -> (memref<f32>, index, index, index, index, index) {
  // The layout is fully static, so every size/stride and the offset fold to
  // arith.constant ops; only the base buffer still comes from the metadata op.
7  //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
8  //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
9  //   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
10  //   CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
11
12  //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
13  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
14    memref<5x4xf32, strided<[4,1], offset:2>>
15    -> memref<f32>, index, index, index, index, index
16
17  // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
18  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
19    memref<f32>, index, index, index, index, index
20}
21
22// -----
23
24// Check that we simplify subview(src) into:
25// base, offset, sizes, strides = extract_strided_metadata(src)
26// final_sizes = subSizes
27// final_strides = <some math> strides
28// final_offset = <some math> offset
29// reinterpret_cast base to final_offset, final_sizes, final_strides
30//
31// Orig strides: [s0, s1, s2]
32// Sub strides: [subS0, subS1, subS2]
33// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
34// ==> 1 affine map (used for each stride) with two values.
35//
36// Orig offset: origOff
37// Sub offsets: [subO0, subO1, subO2]
38// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
39// ==> 1 affine map with (rank * 2 + 1) symbols
40//
41// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
42// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
43// CHECK-LABEL: func @simplify_subview_all_dynamic
44//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
45//
46//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
47//
48//  CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
49//  CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
50//  CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
51//
52//  CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
53//
54//      CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[FINAL_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]], strides: [%[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]]
55//
56//       CHECK: return %[[RES]]
57func.func @simplify_subview_all_dynamic(
58    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
59    %offset0: index, %offset1: index, %offset2: index,
60    %size0: index, %size1: index, %size2: index,
61    %stride0: index, %stride1: index, %stride2: index)
62    -> memref<?x?x?xf32, strided<[?,?,?], offset:?>> {
63
  // Every offset/size/stride of the subview is dynamic, so the pass must
  // compute each final stride and the final offset with affine.apply ops.
64  %subview = memref.subview %base[%offset0, %offset1, %offset2]
65                                 [%size0, %size1, %size2]
66                                 [%stride0, %stride1, %stride2] :
67    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
68      memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
69
70  return %subview : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
71}
72
73// -----
74
75// Check that we simplify extract_strided_metadata of subview to
76// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
77// strides = base_stride_i * subview_stride_i
78// offset = base_offset + sum(subview_offsets_i * base_strides_i).
79//
80// This test also checks that we don't create useless arith operations
81// when subview_offsets_i is 0.
82//
83// CHECK-LABEL: func @extract_strided_metadata_of_subview
84//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>)
85//
86// Materialize the offset for dimension 1.
87//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
88//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
89//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
90//
91// Plain extract_strided_metadata.
92//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
93//
94// Final offset is:
95//   origOffset + (== 0)
96//   base_stride0 * subview_offset0 + (== 4 * 0 == 0)
97//   base_stride1 * subview_offset1 (== 1 * 2)
98//  == 2
99//
100// Return the new tuple.
101//       CHECK: return %[[BASE]], %[[C2]], %[[C2]], %[[C2]], %[[C4]], %[[C1]]
102func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>)
103    -> (memref<f32>, index, index, index, index, index) {
104
  // Static subview of a static memref: the metadata of the subview folds to
  // constants (offset 2, sizes 2x2, strides 4 and 1), and the subview op
  // itself becomes dead.
105  %subview = memref.subview %base[0, 2][2, 2][1, 1] :
106    memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>>
107
108  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
109    memref<2x2xf32, strided<[4,1], offset:2>>
110    -> memref<f32>, index, index, index, index, index
111
112  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
113    memref<f32>, index, index, index, index, index
114}
115
116// -----
117
118// Check that we simplify extract_strided_metadata of subview properly
119// when dynamic sizes are involved.
120// See extract_strided_metadata_of_subview for an explanation of the actual
121// expansion.
122// Orig strides: [64, 4, 1]
123// Sub strides: [1, 1, 1]
124// => New strides: [64, 4, 1]
125//
126// Orig offset: 0
127// Sub offsets: [3, 4, 2]
128// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
129//
130// Final sizes == subview sizes == [%size, 6, 3]
131//
132// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size
133//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
134//  CHECK-SAME: %[[DYN_SIZE:.*]]: index)
135//
136//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
137//   CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
138//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
139//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
140//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
141//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
142//
143//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
144//
145//       CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]]
146func.func @extract_strided_metadata_of_subview_with_dynamic_size(
147    %base: memref<8x16x4xf32>, %size: index)
148    -> (memref<f32>, index, index, index, index, index, index, index) {
149
  // Only the first subview size is dynamic (%size); offsets and strides are
  // static, so the offset folds to 210 and the strides to 64/4/1, while the
  // dynamic size is forwarded unchanged.
150  %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] :
151    memref<8x16x4xf32> to memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>>
152
153  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
154    memref<?x6x3xf32, strided<[64,4,1], offset: 210>>
155    -> memref<f32>, index, index, index, index, index, index, index
156
157  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
158    memref<f32>, index, index, index, index, index, index, index
159}
160
161// -----
162
163// Check that we simplify extract_strided_metadata of subview properly
164// when the subview reduces the ranks.
165// In particular the returned strides must come from #1 and #2 of the %strides
166// value of the new extract_strided_metadata_of_subview, not #0 and #1.
167// See extract_strided_metadata_of_subview for an explanation of the actual
168// expansion.
169//
170// Orig strides: [64, 4, 1]
171// Sub strides: [1, 1, 1]
172// => New strides: [64, 4, 1]
173// Final strides == filterOutReducedDim(new strides, 0) == [4 , 1]
174//
175// Orig offset: 0
176// Sub offsets: [3, 4, 2]
177// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
178//
179// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
180//
181// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview
182//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>)
183//
184//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
185//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
186//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
187//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
188//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
189//
190//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
191//
192//       CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
193func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>)
194    -> (memref<f32>, index, index, index, index, index) {
195
  // The unit dimension (size 1 at dim 0) is dropped by the subview, so the
  // 2-D result metadata must pick the base sizes/strides of dims 1 and 2.
196  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] :
197    memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
198
199  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
200    memref<6x3xf32, strided<[4,1], offset: 210>>
201    -> memref<f32>, index, index, index, index, index
202
203  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
204    memref<f32>, index, index, index, index, index
205}
206
207// -----
208
209// Check that we simplify extract_strided_metadata of subview properly
210// when the subview reduces the rank and some of the strides are variable.
211// In particular, we check that:
212// A. The dynamic stride is multiplied with the base stride to create the new
213//    stride for dimension 1.
214// B. The first returned stride is the value computed in #A.
215// See extract_strided_metadata_of_subview for an explanation of the actual
216// expansion.
217//
218// Orig strides: [64, 4, 1]
219// Sub strides: [1, %stride, 1]
220// => New strides: [64, 4 * %stride, 1]
221// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride , 1]
222//
223// Orig offset: 0
224// Sub offsets: [3, 4, 2]
225// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
226//
227//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
228// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
229//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
230//  CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
231//
232//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
233//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
234//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
235//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
236//
237//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
238//
239//   CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]]
240//
241//       CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
242func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
243    %base: memref<8x16x4xf32>, %stride: index)
244    -> (memref<f32>, index, index, index, index, index) {
245
  // Rank-reducing subview with a dynamic stride on dim 1: the first returned
  // stride is the affine.apply of 4 * %stride, not a base stride.
246  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
247    memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
248
249  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
250    memref<6x3xf32, strided<[?, 1], offset: 210>>
251    -> memref<f32>, index, index, index, index, index
252
253  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
254    memref<f32>, index, index, index, index, index
255}
256
257// -----
258
259// Check that we simplify extract_strided_metadata of subview properly
260// when the subview uses variable offsets.
261// See extract_strided_metadata_of_subview for an explanation of the actual
262// expansion.
263//
264// Orig strides: [128, 1]
265// Sub strides: [1, 1]
266// => New strides: [128, 1]
267//
268// Orig offset: 0
269// Sub offsets: [%arg1, %arg2]
270// => Final offset: 128 * arg1 + 1 * %arg2 + 0
271//
272//   CHECK-DAG: #[[$OFFSETS_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
273// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset
274//  CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>,
275//  CHECK-SAME: %[[DYN_OFFSET0:.*]]: index,
276//  CHECK-SAME: %[[DYN_OFFSET1:.*]]: index)
277//
278//   CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
279//   CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
280//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
281//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
282//
283//   CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSETS_MAP]]()[%[[DYN_OFFSET0]], %[[DYN_OFFSET1]]]
284//
285//       CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]]
286func.func @extract_strided_metadata_of_subview_w_variable_offset(
287    %arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index)
288    -> (memref<f32>, index, index, index, index, index) {
289
  // Only the offsets are dynamic: sizes and strides fold to constants and the
  // final offset is one affine.apply of 128 * %arg1 + %arg2.
290  %subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] :
291    memref<384x128xf32> to memref<64x64xf32, strided<[128, 1], offset: ?>>
292
293  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
294  memref<64x64xf32, strided<[128, 1], offset: ?>> -> memref<f32>, index, index, index, index, index
295
296  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
297    memref<f32>, index, index, index, index, index
298}
299
300// -----
301
302// Check that all the math is correct for all types of computations.
303// We achieve that by using dynamic values for all the different types:
304// - Offsets
305// - Sizes
306// - Strides
307//
308// Orig strides: [s0, s1, s2]
309// Sub strides: [subS0, subS1, subS2]
310// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
311// ==> 1 affine map (used for each stride) with two values.
312//
313// Orig offset: origOff
314// Sub offsets: [subO0, subO1, subO2]
315// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
316// ==> 1 affine map with (rank * 2 + 1) symbols
317//
318// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
319// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
320// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic
321//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
322//
323//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
324//
325//  CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
326//  CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
327//  CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
328//
329//  CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
330//
331//       CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]
332func.func @extract_strided_metadata_of_subview_all_dynamic(
333    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
334    %offset0: index, %offset1: index, %offset2: index,
335    %size0: index, %size1: index, %size2: index,
336    %stride0: index, %stride1: index, %stride2: index)
337    -> (memref<f32>, index, index, index, index, index, index, index) {
338
  // Fully dynamic offsets/sizes/strides: the metadata of the subview is
  // rebuilt from the base metadata via one stride map per dim plus a single
  // 7-symbol offset map; the sizes are forwarded unchanged.
339  %subview = memref.subview %base[%offset0, %offset1, %offset2]
340                                 [%size0, %size1, %size2]
341                                 [%stride0, %stride1, %stride2] :
342    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
343      memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
344
345  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
346    memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
347    -> memref<f32>, index, index, index, index, index, index, index
348
349  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
350    memref<f32>, index, index, index, index, index, index, index
351}
352
353// -----
354
355// Check that we properly simplify expand_shape into:
356// reinterpret_cast(extract_strided_metadata) + <some math>
357//
358// Here we have:
359// For the group applying to dim0:
360// size 0 = baseSizes#0 / (all static sizes in that group)
361//        = baseSizes#0 / (7 * 8 * 9)
362//        = baseSizes#0 / 504
363// size 1 = 7
364// size 2 = 8
365// size 3 = 9
366// stride 0 = baseStrides#0 * 7 * 8 * 9
367//          = baseStrides#0 * 504
368// stride 1 = baseStrides#0 * 8 * 9
369//          = baseStrides#0 * 72
370// stride 2 = baseStrides#0 * 9
371// stride 3 = baseStrides#0
372//
373// For the group applying to dim1:
374// size 4 = 10
375// size 5 = 2
376// size 6 = baseSizes#1 / (all static sizes in that group)
377//        = baseSizes#1 / (10 * 2 * 3)
378//        = baseSizes#1 / 60
379// size 7 = 3
380// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
381//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
382//          = baseStrides#1 * (baseSizes#1 / 60) * 6
383//          and since we know that baseSizes#1 is a multiple of 60:
384//          = baseStrides#1 * (baseSizes#1 / 10)
385// stride 5 = baseStrides#1 * size 6 * size 7
386//          = baseStrides#1 * (baseSizes#1 / 60) * 3
387//          = baseStrides#1 * (baseSizes#1 / 20)
388// stride 6 = baseStrides#1 * size 7
389//          = baseStrides#1 * 3
390// stride 7 = baseStrides#1
391//
392// Base and offset are unchanged.
393//
394//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
395//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
396//
397//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
398//   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
399//   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
400//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
401//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
402//   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
403// CHECK-LABEL: func @simplify_expand_shape
404//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
405//
406//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
407//
408//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
409//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
410//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
411//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
412//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
413//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
414//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
415//   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
416//
417//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
418//
419//   CHECK: return %[[REINTERPRET_CAST]]
420func.func @simplify_expand_shape(
421    %base: memref<?x?xf32, strided<[?,?], offset:?>>,
422    %offset0: index, %offset1: index, %offset2: index,
423    %size0: index, %size1: index, %size2: index,
424    %stride0: index, %stride1: index, %stride2: index,
425    %sz0: index, %sz1: index)
426    -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
427
  // The expand_shape itself is rewritten into a reinterpret_cast fed by the
  // base metadata; only %base/%sz0/%sz1 are used — the other index arguments
  // are unused in this body.
428  %subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
429    memref<?x?xf32, strided<[?,?], offset: ?>> into
430      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
431
432  return %subview :
433    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
434}
435
436// -----
437
438// Check that we properly simplify extract_strided_metadata of expand_shape
439// into:
440// baseBuffer, baseOffset, baseSizes, baseStrides =
441//     extract_strided_metadata(memref)
442// sizes#reassIdx =
443//     baseSizes#reassDim / product(expandShapeSizes#j,
444//                                  for j in group excluding reassIdx)
445// strides#reassIdx =
446//     baseStrides#reassDim * product(expandShapeSizes#j, for j in
447//                                    reassIdx+1..reassIdx+group.size)
448//
449// Here we have:
450// For the group applying to dim0:
451// size 0 = 3
452// size 1 = 5
453// size 2 = 2
454// stride 0 = baseStrides#0 * 5 * 2
455//          = 4 * 5 * 2
456//          = 40
457// stride 1 = baseStrides#0 * 2
458//          = 4 * 2
459//          = 8
460// stride 2 = baseStrides#0
461//          = 4
462//
463// For the group applying to dim1:
464// size 3 = 2
465// size 4 = 2
466// stride 3 = baseStrides#1 * 2
467//          = 1 * 2
468//          = 2
469// stride 4 = baseStrides#1
470//          = 1
471//
472// Base and offset are unchanged.
473//
474// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
475//  CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
476//
477//   CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
478//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
479//   CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
480//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
481//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
482//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
483//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
484//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
485//
486//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
487//
488//   CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
489func.func @extract_strided_metadata_of_expand_shape_all_static(
490    %arg : memref<30x4xi16>)
491    -> (memref<i16>, index,
492       index, index, index, index, index,
493       index, index, index, index, index) {
494
  // Static expand_shape 30x4 -> 3x5x2x2x2: all sizes and strides fold to
  // constants, so the expand_shape op becomes dead.
495  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] output_shape [3, 5, 2, 2, 2] :
496    memref<30x4xi16> into memref<3x5x2x2x2xi16>
497
498  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
499    memref<3x5x2x2x2xi16>
500    -> memref<i16>, index,
501       index, index, index, index, index,
502       index, index, index, index, index
503
504  return %base, %offset,
505    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
506    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
507      memref<i16>, index,
508      index, index, index, index, index,
509      index, index, index, index, index
510}
511
512// -----
513
514// Check that we properly simplify extract_strided_metadata of expand_shape
515// when dynamic sizes, strides, and offsets are involved.
516// See extract_strided_metadata_of_expand_shape_all_static for an explanation
517// of the expansion.
518//
519// One of the important characteristic of this test is that the dynamic
520// dimensions produced by the expand_shape appear both in the first dimension
521// (for group 1) and the non-first dimension (second dimension for group 2.)
522// The idea is to make sure that:
523// 1. We properly account for dynamic shapes even when the strides are not
524//    affected by them. (When the dynamic dimension is the first one.)
525// 2. We properly compute the strides affected by dynamic shapes. (When the
526//    dynamic dimension is not the first one.)
527//
528// Here we have:
529// For the group applying to dim0:
530// size 0 = baseSizes#0 / (all static sizes in that group)
531//        = baseSizes#0 / (7 * 8 * 9)
532//        = baseSizes#0 / 504
533// size 1 = 7
534// size 2 = 8
535// size 3 = 9
536// stride 0 = baseStrides#0 * 7 * 8 * 9
537//          = baseStrides#0 * 504
538// stride 1 = baseStrides#0 * 8 * 9
539//          = baseStrides#0 * 72
540// stride 2 = baseStrides#0 * 9
541// stride 3 = baseStrides#0
542//
543// For the group applying to dim1:
544// size 4 = 10
545// size 5 = 2
546// size 6 = baseSizes#1 / (all static sizes in that group)
547//        = baseSizes#1 / (10 * 2 * 3)
548//        = baseSizes#1 / 60
549// size 7 = 3
550// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
551//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
552//          = baseStrides#1 * (baseSizes#1 / 60) * 6
553//          and since we know that baseSizes#1 is a multiple of 60:
554//          = baseStrides#1 * (baseSizes#1 / 10)
555// stride 5 = baseStrides#1 * size 6 * size 7
556//          = baseStrides#1 * (baseSizes#1 / 60) * 3
557//          = baseStrides#1 * (baseSizes#1 / 20)
558// stride 6 = baseStrides#1 * size 7
559//          = baseStrides#1 * 3
560// stride 7 = baseStrides#1
561//
562// Base and offset are unchanged.
563//
564//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
565//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
566//
567//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
568//   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
569//   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
570//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
571//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
572//   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
573// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
574//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
575//
576//   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
577//   CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
578//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
579//   CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
580//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
581//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
582//
583//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
584//
585//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
586//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
587//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
588//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
589//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
590//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
591//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
592//   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
593
594//   CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
595func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
596    %base: memref<?x?xf32, strided<[?,?], offset:?>>,
597    %offset0: index, %offset1: index, %offset2: index,
598    %size0: index, %size1: index, %size2: index,
599    %stride0: index, %stride1: index, %stride2: index,
600    %sz0: index, %sz1: index)
601    -> (memref<f32>, index,
602       index, index, index, index, index, index, index, index,
603       index, index, index, index, index, index, index, index) {
604
  // Same expansion as simplify_expand_shape, but here the expanded metadata
  // is queried directly, so the individual size/stride affine.apply results
  // are returned instead of a reinterpret_cast. The extra %offset*/%size*/
  // %stride* arguments are unused in this body.
605  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
606    memref<?x?xf32, strided<[?,?], offset: ?>> into
607      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
608
609  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %subview :
610    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
611    -> memref<f32>, index,
612       index, index, index, index, index, index, index, index,
613       index, index, index, index, index, index, index, index
614
615  return %base_buffer, %offset,
616    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7,
617    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 :
618      memref<f32>, index,
619      index, index, index, index, index, index, index, index,
620      index, index, index, index, index, index, index, index
621}
622
623
624// -----
625
626// Check that we properly handle extract_strided_metadata of expand_shape for
627// 0-D input.
628// The 0-D case is pretty boring:
629// All expanded sizes are 1, likewise for the strides, and we keep the
630// original base and offset.
631// We have still a test for it, because since the input reassociation map
632// of the expand_shape is empty, the handling of such shape hits a corner
633// case.
634// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank
635//  CHECK-SAME: (%[[ARG:.*]]: memref<i16, strided<[], offset: ?>>)
636//
637//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
638//
639//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref<i16, strided<[], offset: ?>> -> memref<i16>, index
640//
641//   CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
    %arg : memref<i16, strided<[], offset: ?>>)
    -> (memref<i16>, index,
       index, index, index, index, index,
       index, index, index, index, index) {

  // Expand the 0-D input with an empty reassociation list; all resulting
  // sizes and strides are the static constant 1.
  %expand_shape = memref.expand_shape %arg[] output_shape [1, 1, 1, 1, 1] :
    memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>

  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
    memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
    -> memref<i16>, index,
       index, index, index, index, index,
       index, index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
      memref<i16>, index,
      index, index, index, index, index,
      index, index, index, index, index
}
664
665// -----
666
667// Check that we simplify extract_strided_metadata(alloc)
668// into simply the alloc with the information extracted from
669// the memref type and arguments of the alloc.
670//
671// baseBuffer = reinterpret_cast alloc
672// offset = 0
673// sizes = shape(memref)
674// strides = strides(memref)
675//
676// For dynamic shapes, we simply use the values that feed the alloc.
677//
678// Simple rank 0 test: we don't need a reinterpret_cast here.
679// CHECK-LABEL: func @extract_strided_metadata_of_alloc_all_static_0_rank
680//
681//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
682//   CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
683//       CHECK: return %[[ALLOC]], %[[C0]] : memref<i16>, index
func.func @extract_strided_metadata_of_alloc_all_static_0_rank()
    -> (memref<i16>, index) {

  // A 0-D alloc with the trivial layout: the metadata folds to the alloc
  // itself plus a constant offset of 0 (no reinterpret_cast needed).
  %A = memref.alloc() : memref<i16>
  %base, %offset = memref.extract_strided_metadata %A :
    memref<i16>
    -> memref<i16>, index

  return %base, %offset :
      memref<i16>, index
}
695
696// -----
697
698// Simplification of extract_strided_metadata(alloc).
699// Check that we properly use the dynamic sizes to
700// create the new sizes and strides.
701// size 0 = dyn_size0
702// size 1 = 4
703// size 2 = dyn_size2
704// size 3 = dyn_size3
705//
706// stride 0 = size 1 * size 2 * size 3
707//          = 4 * dyn_size2 * dyn_size3
708// stride 1 = size 2 * size 3
709//          = dyn_size2 * dyn_size3
710// stride 2 = size 3
711//          = dyn_size3
712// stride 3 = 1
713//
714//   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
715//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
716// CHECK-LABEL: extract_strided_metadata_of_alloc_dyn_size
717//  CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_SIZE3:.*]]: index)
718//
719//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
720//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
721//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
722//   CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYN_SIZE0]], %[[DYN_SIZE2]], %[[DYN_SIZE3]])
723//
724//   CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
725//   CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
726//
727//   CHECK-DAG:  %[[CASTED_ALLOC:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<?x4x?x?xi16> to memref<i16>
728//
729//       CHECK: return %[[CASTED_ALLOC]], %[[C0]], %[[DYN_SIZE0]], %[[C4]], %[[DYN_SIZE2]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
func.func @extract_strided_metadata_of_alloc_dyn_size(
  %dyn_size0 : index, %dyn_size2 : index, %dyn_size3 : index)
    -> (memref<i16>, index,
        index, index, index, index,
        index, index, index, index) {

  // Dimensions 0, 2, and 3 are dynamic; dimension 1 is the static size 4.
  // The sizes feeding the alloc are reused directly as the metadata sizes.
  %A = memref.alloc(%dyn_size0, %dyn_size2, %dyn_size3) : memref<?x4x?x?xi16>

  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
    memref<?x4x?x?xi16>
    -> memref<i16>, index,
       index, index, index, index,
       index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
    %strides#0, %strides#1, %strides#2, %strides#3 :
      memref<i16>, index,
      index, index, index, index,
      index, index, index, index
}
751
752// -----
753
754// Same check as extract_strided_metadata_of_alloc_dyn_size but alloca
755// instead of alloc. Just to make sure we handle allocas the same way
756// we do with alloc.
757// While at it, test a slightly different shape than
758// extract_strided_metadata_of_alloc_dyn_size.
759//
760// size 0 = dyn_size0
761// size 1 = dyn_size1
762// size 2 = 4
763// size 3 = dyn_size3
764//
765// stride 0 = size 1 * size 2 * size 3
766//          = dyn_size1 * 4 * dyn_size3
767// stride 1 = size 2 * size 3
768//          = 4 * dyn_size3
769// stride 2 = size 3
770//          = dyn_size3
771// stride 3 = 1
772//
773//   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
774//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
775// CHECK-LABEL: extract_strided_metadata_of_alloca_dyn_size
776//  CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE3:.*]]: index)
777//
778//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
779//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
780//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
781//   CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca(%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE3]])
782//
783//   CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE1]], %[[DYN_SIZE3]]]
784//   CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE3]]]
785//
786//   CHECK-DAG:  %[[CASTED_ALLOCA:.*]] = memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [], strides: [] : memref<?x?x4x?xi16> to memref<i16>
787//
788//       CHECK: return %[[CASTED_ALLOCA]], %[[C0]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[C4]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
func.func @extract_strided_metadata_of_alloca_dyn_size(
  %dyn_size0 : index, %dyn_size1 : index, %dyn_size3 : index)
    -> (memref<i16>, index,
        index, index, index, index,
        index, index, index, index) {

  // Dimensions 0, 1, and 3 are dynamic; dimension 2 is the static size 4.
  // Same simplification as for alloc, exercised on alloca.
  %A = memref.alloca(%dyn_size0, %dyn_size1, %dyn_size3) : memref<?x?x4x?xi16>

  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
    memref<?x?x4x?xi16>
    -> memref<i16>, index,
       index, index, index, index,
       index, index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
    %strides#0, %strides#1, %strides#2, %strides#3 :
      memref<i16>, index,
      index, index, index, index,
      index, index, index, index
}
810
811// -----
812
// The following few alloc tests are negative tests (the simplification
// doesn't happen) to make sure non-trivial memref types are treated
// as "not normalized".
816// CHECK-LABEL: extract_strided_metadata_of_alloc_with_variable_offset
817//       CHECK: %[[ALLOC:.*]] = memref.alloc
818//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
819//       CHECK: return %[[BASE]]
#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
func.func @extract_strided_metadata_of_alloc_with_variable_offset(%arg : index)
    -> (memref<i16>, index, index, index) {

  // The layout map carries a symbolic (variable) offset, so the
  // simplification must not kick in.
  %A = memref.alloc()[%arg] : memref<4xi16, #map0>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, #map0>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
      memref<i16>, index, index, index
}
832
833// -----
834
835// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset
836//       CHECK: %[[ALLOC:.*]] = memref.alloc
837//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
838//       CHECK: return %[[BASE]]
#map0 = affine_map<(d0) -> (d0 + 12)>
func.func @extract_strided_metadata_of_alloc_with_cst_offset(%arg : index)
    -> (memref<i16>, index, index, index) {

  // The layout map has a constant non-zero offset (12), so the
  // simplification must not kick in.
  %A = memref.alloc() : memref<4xi16, #map0>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, #map0>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
      memref<i16>, index, index, index
}
851
852// -----
853
854// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset_in_type
855//       CHECK: %[[ALLOC:.*]] = memref.alloc
856//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
857//       CHECK: return %[[BASE]]
func.func @extract_strided_metadata_of_alloc_with_cst_offset_in_type(%arg : index)
    -> (memref<i16>, index, index, index) {

  // The strided layout carries a constant non-zero offset (10), so the
  // simplification must not kick in.
  %A = memref.alloc() : memref<4xi16, strided<[1], offset : 10>>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, strided<[1], offset : 10>>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
      memref<i16>, index, index, index
}
869
870// -----
871
872// CHECK-LABEL: extract_strided_metadata_of_alloc_with_strided
873//       CHECK: %[[ALLOC:.*]] = memref.alloc
874//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
875//       CHECK: return %[[BASE]]
func.func @extract_strided_metadata_of_alloc_with_strided(%arg : index)
    -> (memref<i16>, index, index, index) {

  // The strided layout has a non-canonical stride (12 instead of 1), so
  // the simplification must not kick in.
  %A = memref.alloc() : memref<4xi16, strided<[12]>>
  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
    memref<4xi16, strided<[12]>>
    -> memref<i16>, index, index, index

  return %base, %offset, %size, %stride :
      memref<i16>, index, index, index
}
887
888// -----
889
890// CHECK-LABEL: extract_aligned_pointer_as_index
891//  CHECK-SAME: (%[[ARG0:.*]]: memref<f32>
func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
  // CHECK-NOT: memref.subview
  //     CHECK: memref.extract_aligned_pointer_as_index %[[ARG0]]
  // The subview result %c is unused; the pointer is taken from %arg0
  // directly, so the subview should disappear entirely.
  %c = memref.subview %arg0[] [] [] : memref<f32> to memref<f32>
  %r = memref.extract_aligned_pointer_as_index %arg0: memref<f32> -> index
  return %r : index
}
899
900// -----
901
902// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
903//  CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
  // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
  // CHECK: return %[[I]]

  // The reinterpret_cast is looked through: the pointer is extracted from
  // the unranked source itself.
  %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
  %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
  return %i : index
}
912
913// -----
914
915// Check that we simplify collapse_shape into
916// reinterpret_cast(extract_strided_metadata) + <some math>
917//
918// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
919// Size 0 = origSize0
920// Size 1 = origSize1 * origSize2 * origSize3
921//        = origSize1 * 4 * origSize3
922// Size 2 = origSize4 * origSize5
923//        = 6 * 7
924//        = 42
925// Stride 0 = min(origStride0)
926//          = Right now the folder of affine.min is not smart
927//            enough to just return origStride0
928// Stride 1 = min(origStride1, origStride2, origStride3)
929//          = min(origStride1, origStride2, 42)
930// Stride 2 = min(origStride4, origStride5)
931//          = min(7, 1)
932//          = 1
933//
934//       CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
935// CHECK-LABEL: func @simplify_collapse(
936//  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
937//
938//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
939//
940//       CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
941//
942//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
  -> memref<?x?x42xi32> {

  // Reassociation groups: [0], [1, 2, 3], [4, 5]; see the size/stride
  // math in the comment block above this test.
  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>

  return %collapsed_view : memref<?x?x42xi32>

}
952
953// -----
954
955// Check that we simplify collapse_shape into
956// reinterpret_cast(extract_strided_metadata) + <some math>
957// when there are dimensions of size 1 involved.
958//
959// We transform: 3x1 to [0, 1]
960//
// The tricky bit here is that the strides between dimensions 0 and 1
// are not truly contiguous, but since we are dealing with a dimension of
// size 1 this is actually fine (i.e., we are not going to jump around.)
964//
965// As a result the resulting stride needs to ignore the strides of the
966// dimensions of size 1.
967//
968// Size 0 = origSize0 * origSize1
969//        = 3 * 1
970//        = 3
// Stride 0 = min(origStride_i, for all i in reassociation group and dim_i != 1)
972//          = min(origStride0)
973//          = min(2)
974//          = 2
975//
976// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
977//  CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
978//
979//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
980//
981//
982//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {

  // The trailing unit dimension is folded into its group; its stride must
  // be ignored when computing the collapsed stride (expected stride: 2).
  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>>

  memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32>

  return
}
992
993
994// -----
995
996// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
997//
998// The tricky bit here is also the resulting stride is meaningless, we still
999// have to please the type system.
1000//
1001// In this case, we're collapsing two strides of respectively 2 and 1 and the
1002// resulting type wants a stride of 2.
1003//
1004// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_non_1_stride(
1005//  CHECK-SAME: %[[ARG:.*]]: memref<1x1xi32, strided<[2, 1]
1006//
1007//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<1x1xi32, strided<[2, 1], offset: ?>>
1008//
1009//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [2]
func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
    (%arg0: memref<1x1xi32, strided<[2, 1], offset: ?>>)
    -> memref<1xi32, strided<[2], offset: ?>> {

  // Collapsing a 1x1 group: the resulting stride (2) is dictated by the
  // result type even though it is meaningless for a size-1 dimension.
  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<1x1xi32, strided<[2, 1], offset: ?>>
    into memref<1xi32, strided<[2], offset: ?>>

  return %collapse_shape : memref<1xi32, strided<[2], offset: ?>>
}
1020
1021// -----
1022
1023// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
1024// We also have a couple of collapsed dimensions before the 1x1x...x1 group
1025// to make sure we properly index into the dynamic strides based on the
1026// group ID.
1027//
1028// The tricky bit in this test is that the 1x1x...x1 group stride is dynamic
1029// so we have to propagate one of the dynamic dimension for this group.
1030//
1031// For this test we have:
1032// Size0 = origSize0 * origSize1
1033//       = 2 * 3
1034//       = 6
1035// Size1 = origSize2 * origSize3 * origSize4
1036//       = 1 * 1 * 1
1037//       = 1
1038//
1039// Stride0 = min(origStride0, origStride1)
1040// Stride1 = we actually don't know, this is dynamic but we don't know
1041//           which one to pick.
1042//           We just return the first dynamic one for this group.
1043//
1044//
1045// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
1046//  CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
1047//
1048//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
1049//
1050//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
    (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
    -> memref<6x1xi32, strided<[?, ?], offset: ?>> {

  // The 1x1x1 group [2, 3, 4] has only dynamic strides, so the first
  // dynamic stride of that group is propagated for the collapsed dim.
  %collapse_shape = memref.collapse_shape %arg0 [[0, 1], [2, 3, 4]] :
    memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
    into memref<6x1xi32, strided<[?, ?], offset: ?>>

  return %collapse_shape : memref<6x1xi32, strided<[?, ?], offset: ?>>
}
1061
1062// -----
1063
1064// Check that we simplify extract_strided_metadata of collapse_shape.
1065//
1066// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
1067// Size 0 = origSize0
1068// Size 1 = origSize1 * origSize2 * origSize3
1069//        = origSize1 * 4 * origSize3
1070// Size 2 = origSize4 * origSize5
1071//        = 6 * 7
1072//        = 42
1073// Stride 0 = origStride0
1074// Stride 1 = origStride3 (orig stride of the inner most dimension)
1075//          = 42
1076// Stride 2 = origStride5
1077//          = 1
1078//
1079//       CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
1080// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
1081//  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
1082//
1083//   CHECK-DAG: %[[C42:.*]] = arith.constant 42 : index
1084//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
1085//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1086//
1087//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
1088//
1089//   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
1090//
1091//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
  -> (memref<i32>, index,
      index, index, index,
      index, index, index) {

  // Reassociation groups: [0], [1, 2, 3], [4, 5]; see the size/stride
  // math in the comment block above this test.
  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>

  %base, %offset, %sizes:3, %strides:3 =
    memref.extract_strided_metadata %collapsed_view : memref<?x?x42xi32>
    -> memref<i32>, index,
       index, index, index,
       index, index, index

  return %base, %offset,
    %sizes#0, %sizes#1, %sizes#2,
    %strides#0, %strides#1, %strides#2 :
      memref<i32>, index,
      index, index, index,
      index, index, index

}
1114
1115// -----
1116
1117// Check that we simplify extract_strided_metadata of collapse_shape to
1118// a 0-ranked shape.
1119// CHECK-LABEL: func @extract_strided_metadata_of_collapse_to_rank0(
1120//  CHECK-SAME: %[[ARG:.*]]: memref<1x1x1x1x1x1xi32>)
1121//
1122//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1123//
1124//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<1x1x1x1x1x1xi32>
1125//
1126//       CHECK: return %[[BASE]], %[[C0]]
func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1x1xi32>)
  -> (memref<i32>, index) {

  // Collapsing with an empty reassociation list yields a 0-D memref; only
  // the base buffer and a constant zero offset remain.
  %collapsed_view = memref.collapse_shape %arg [] :
    memref<1x1x1x1x1x1xi32> into memref<i32>

  %base, %offset =
    memref.extract_strided_metadata %collapsed_view : memref<i32>
    -> memref<i32>, index

  return %base, %offset :
      memref<i32>, index
}
1140
1141// -----
1142
1143// Check that we simplify extract_strided_metadata of
1144// extract_strided_metadata.
1145//
1146// CHECK-LABEL: func @extract_strided_metadata_of_extract_strided_metadata(
1147//  CHECK-SAME: %[[ARG:.*]]: memref<i32>)
1148//
1149//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1150//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
1151//
1152//       CHECK: return %[[BASE]], %[[C0]]
func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i32>)
  -> (memref<i32>, index) {

  // The second extract_strided_metadata operates on the base buffer of the
  // first one, so it folds to that base plus a constant zero offset.
  %base, %offset =
    memref.extract_strided_metadata %arg:memref<i32>
    -> memref<i32>, index
  %base2, %offset2 =
    memref.extract_strided_metadata %base:memref<i32>
    -> memref<i32>, index

  return %base2, %offset2 :
      memref<i32>, index
}
1166
1167// -----
1168
1169// Check that we simplify extract_strided_metadata of reinterpret_cast
1170// when the source of the reinterpret_cast is compatible with what
1171// `extract_strided_metadata`s accept.
1172//
1173// When we apply the transformation the resulting offset, sizes and strides
1174// should come straight from the inputs of the reinterpret_cast.
1175//
1176// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast
1177//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
1178//
1179//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
1180//
1181//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
func.func @extract_strided_metadata_of_reinterpret_cast(
  %arg : memref<?x?xi32, strided<[?, ?], offset:?>>,
  %offset: index,
  %size0 : index, %size1 : index,
  %stride0 : index, %stride1 : index)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  // The metadata is taken straight from the reinterpret_cast operands:
  // only the base buffer still comes from the source memref.
  %cast =
    memref.reinterpret_cast %arg to
      offset: [%offset],
      sizes: [%size0, %size1],
      strides: [%stride0, %stride1] :
      memref<?x?xi32, strided<[?, ?], offset: ?>> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
1212
1213// -----
1214
1215// Check that we don't simplify extract_strided_metadata of
1216// reinterpret_cast when the source of the cast is unranked.
1217// Unranked memrefs cannot feed into extract_strided_metadata operations.
1218// Note: Technically we could still fold the sizes and strides.
1219//
1220// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_unranked
1221//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
1222//
1223//       CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[DYN_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]]], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]]]
1224//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
1225//
1226//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_reinterpret_cast_unranked(
  %arg : memref<*xi32>,
  %offset: index,
  %size0 : index, %size1 : index,
  %stride0 : index, %stride1 : index)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  // The source is unranked, so the reinterpret_cast must be kept and the
  // metadata is extracted from its ranked result instead.
  %cast =
    memref.reinterpret_cast %arg to
      offset: [%offset],
      sizes: [%size0, %size1],
      strides: [%stride0, %stride1] :
      memref<*xi32> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
1257
1258// -----
1259
1260// Similar to @extract_strided_metadata_of_reinterpret_cast, just make sure
1261// we handle 0-D properly.
1262//
1263// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_rank0
1264//  CHECK-SAME: %[[ARG:.*]]: memref<i32, strided<[], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
1265//
1266//       CHECK: %[[BASE:.*]], %[[BASE_OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
1267//
1268//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
  %arg : memref<i32, strided<[], offset:?>>,
  %offset: index,
  %size0 : index, %size1 : index,
  %stride0 : index, %stride1 : index)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  // 0-D ranked source: the metadata still comes straight from the
  // reinterpret_cast operands.
  %cast =
    memref.reinterpret_cast %arg to
      offset: [%offset],
      sizes: [%size0, %size1],
      strides: [%stride0, %stride1] :
      memref<i32, strided<[], offset: ?>> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
1299
1300// -----
1301
// Check that `memref.get_global` -> `memref.extract_strided_metadata` resolves,
// with the consumer replaced by the strides, sizes, and offset computed from
// `memref.get_global`. Since the result of `memref.get_global` is always
// statically shaped, there is no need to check for dynamic shapes.
1306
1307// CHECK-LABEL: func @extract_strided_metadata_of_get_global()
1308//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
1309//   CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
1310//   CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
1311//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
1312//       CHECK:   %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
1313//       CHECK:   %[[CAST:.+]] = memref.reinterpret_cast %[[GET_GLOBAL]]
1314//  CHECK-SAME:       offset: [0], sizes: [], strides: []
1315//       CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
1316
memref.global "private" constant @const_i32 : memref<512x384xi32> = dense<42>

func.func @extract_strided_metadata_of_get_global()
    -> (memref<i32>, index, index, index, index, index) {

  // Identity layout: sizes (512, 384), strides (384, 1), and offset (0)
  // are all static constants derived from the global's type.
  %A = memref.get_global @const_i32 : memref<512x384xi32>

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
    memref<512x384xi32> -> memref<i32>, index, index, index, index, index

  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
      memref<i32>, index, index, index, index, index
}
1330
1331// -----
1332
// Check that `memref.get_global` -> `memref.extract_strided_metadata` does not
// resolve when the strides are not the identity. This is an unhandled case that
// could be covered in the future.
1336
1337// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_strides()
1338//       CHECK:   %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
1339//       CHECK:   memref.extract_strided_metadata %[[GET_GLOBAL]]
memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>> = dense<42>

func.func @extract_strided_metadata_of_get_global_with_strides()
    -> (memref<i32>, index, index, index, index, index) {

  // Non-identity strides (420 instead of the canonical 384): the
  // extract_strided_metadata must not be resolved.
  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>>

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
    memref<512x384xi32, strided<[420, 1], offset: 0>>
    -> memref<i32>, index, index, index, index, index

  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
      memref<i32>, index, index, index, index, index
}
1354
1355// -----
1356
// Check that `memref.get_global` -> `memref.extract_strided_metadata` does not
// resolve when the offset is non-zero. This is an unhandled case that could be
// covered in the future.
1360
1361// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_offset()
1362//       CHECK:   %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
1363//       CHECK:   memref.extract_strided_metadata %[[GET_GLOBAL]]
memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>> = dense<42>

func.func @extract_strided_metadata_of_get_global_with_offset()
    -> (memref<i32>, index, index, index, index, index) {

  // Non-zero offset (20) in the global's layout: the
  // extract_strided_metadata must not be resolved.
  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>>

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
    memref<512x384xi32, strided<[384, 1], offset: 20>>
    -> memref<i32>, index, index, index, index, index

  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
      memref<i32>, index, index, index, index, index
}
1378
1379// -----
1380
1381// Check that we simplify extract_strided_metadata of cast
1382// when the source of the cast is compatible with what
1383// `extract_strided_metadata`s accept.
1384//
1385// When we apply the transformation the resulting offset, sizes and strides
1386// should come straight from the inputs of the cast.
1387// Additionally the folder on extract_strided_metadata should propagate the
1388// static information.
1389//
1390// CHECK-LABEL: func @extract_strided_metadata_of_cast
1391//  CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
1392//
1393//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
1394//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1395//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1396//
1397//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
  %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  // The cast only erases static information; the source metadata is
  // reused, with the static size 3 and stride 4 folded back in.
  %cast =
    memref.cast %arg :
      memref<3x?xi32, strided<[4, ?], offset: ?>> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
1422
1423// -----
1424
// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata` accepts (a ranked, strided memref).
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type: here the static information flows from the
// *destination* type of the cast (offset 25, size 4 of dim 0, stride 18 of
// dim 1), while the remaining size/stride stay dynamic and come from the
// metadata of %arg.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
//   CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
//   CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
//       CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
  %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  // Cast a fully dynamic layout to one with some static components.
  %cast =
    memref.cast %arg :
      memref<?x?xi32, strided<[?, ?], offset: ?>> to
      memref<4x?xi32, strided<[?, 18], offset: 25>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
1465
1466// -----
1467
// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations,
// so the metadata cannot be re-extracted from %arg and the cast must stay.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
  %arg : memref<*xi32>)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  // Rank the unranked memref; this cast is the only valid anchor for the
  // extract_strided_metadata below and is expected to be preserved.
  %cast =
    memref.cast %arg :
      memref<*xi32> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
1504
1505
1506// -----
1507
memref.global "private" @dynamicShmem : memref<0xf16,3>

// Check that a zero-sized (0-element, rank-1) memref does not trip up the
// expansion: the base buffer is rebuilt with a rank-0 reinterpret_cast and
// the offset/size/stride results fold to constants.
// (Fixed typo: "memred" -> "memref" in both the function name and the label.)
// CHECK-LABEL: func @zero_sized_memref
func.func @zero_sized_memref(%arg0: f32) -> (memref<f16, 3>, index, index, index) {
  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>

  // CHECK: %[[BASE:.*]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
  // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [], strides: [] : memref<0xf16, 3> to memref<f16, 3>
  // CHECK: return %[[CAST]]

  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %dynamicMem : memref<0xf16, 3> -> memref<f16, 3>, index, index, index
  return %base_buffer, %offset,
    %sizes, %strides :
      memref<f16, 3>, index,
      index, index
}
1525
1526// -----
1527
// Check extract_strided_metadata of a collapse_shape with a fully static
// result type: the base buffer comes from the source of the collapse and
// offset/size/stride fold to the constants 0/20/1.
func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
    -> (memref<f32>, index, index, index) {

  // Collapse both dimensions into one: 5x4 -> 20 elements, contiguous.
  %collapse = memref.collapse_shape %base[[0, 1]] :
    memref<5x4xf32> into memref<20xf32>

  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse :
    memref<20xf32> -> memref<f32>, index, index, index

  return %base_buffer, %offset, %size, %stride :
    memref<f32>, index, index, index
}

// CHECK-LABEL:  func @extract_strided_metadata_of_collapse_shape
//   CHECK-DAG:    %[[OFFSET:.*]] = arith.constant 0 : index
//   CHECK-DAG:    %[[SIZE:.*]] = arith.constant 20 : index
//   CHECK-DAG:    %[[STEP:.*]] = arith.constant 1 : index
//       CHECK:    %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
//       CHECK:    return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index
1548// -----
1549
// Check extract_strided_metadata of a memory_space_cast: the metadata is
// taken from the source of the cast (offset/size/stride fold to 0/20/1),
// while the returned base buffer keeps the casted memory space — the
// memory_space_cast is reapplied to the extracted base buffer.
func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
    -> (memref<f32, 1>, index, index, index) {

  // Move the buffer to memory space 1; the layout metadata is unaffected.
  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>

  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
    memref<20xf32, 1> -> memref<f32, 1>, index, index, index

  return %base_buffer, %offset, %size, %stride :
    memref<f32, 1>, index, index, index
}

// CHECK-LABEL:  func @extract_strided_metadata_of_memory_space_cast
//   CHECK-DAG:    %[[OFFSET:.*]] = arith.constant 0 : index
//   CHECK-DAG:    %[[SIZE:.*]] = arith.constant 20 : index
//   CHECK-DAG:    %[[STEP:.*]] = arith.constant 1 : index
//       CHECK:    %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
//       CHECK:    %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
//       CHECK:    return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index
1570// -----
1571
// Same as extract_strided_metadata_of_memory_space_cast, but the base
// buffer is not returned: only the (statically known) offset/size/stride
// are used, so once they fold to constants the casted base buffer has no
// users and the memory_space_cast must disappear from the output entirely.
func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
    -> (index, index, index) {

  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>

  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
    memref<20xf32, 1> -> memref<f32, 1>, index, index, index

  // Deliberately drop %base_buffer so the cast becomes dead.
  return %offset, %size, %stride : index, index, index
}

// CHECK-LABEL:  func @extract_strided_metadata_of_memory_space_cast_no_base
//   CHECK-NOT:  memref.memory_space_cast
1585