// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s

// This file tests whether memref types with non-trivial layout maps are
// normalized to trivial (identity) layouts.

// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
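// These reduction maps are referenced by the @memref_load_with_reduction_map
// test at the end of this file.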

// CHECK-LABEL: func @permute()
func.func @permute() {
  %A = memref.alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  affine.for %i = 0 to 64 {
    affine.for %j = 0 to 256 {
      %1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  memref.dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  return
}
// The old memref alloc should disappear.
// CHECK-NOT:  memref<64x256xf32>
// CHECK:      [[MEM:%[0-9a-zA-Z_]+]] = memref.alloc() : memref<256x64xf32>
// CHECK-NEXT: affine.for %[[I:arg[0-9a-zA-Z_]+]] = 0 to 64 {
// CHECK-NEXT:   affine.for %[[J:arg[0-9a-zA-Z_]+]] = 0 to 256 {
// CHECK-NEXT:     affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT:     "prevent.dce"
// CHECK-NEXT:   }
// CHECK-NEXT: }
// CHECK-NEXT: memref.dealloc [[MEM]]
// CHECK-NEXT: return
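// The transposed layout map (d0, d1) -> (d1, d0) stores element (i, j) at
// physical position (j, i); for example, A[10, 20] of the original
// memref<64x256xf32> lands at [20, 10] in the normalized memref<256x64xf32>,
// which is why the load indices above are swapped.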

// CHECK-LABEL: func @alloca
func.func @alloca(%idx : index) {
  // CHECK-NEXT: memref.alloca() : memref<65xf32>
  %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  affine.for %i = 0 to 64 {
    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    "prevent.dce"(%1) : (f32) -> ()
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
  }
  return
}
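// The shift map (d0) -> (d0 + 1) moves every access up by one element; the
// highest index 63 maps to 64, so the normalized buffer holds 65 elements.
// The same applies to @shift below.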

// CHECK-LABEL: func @shift
func.func @shift(%idx : index) {
  // CHECK-NEXT: memref.alloc() : memref<65xf32>
  %A = memref.alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  affine.for %i = 0 to 64 {
    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    "prevent.dce"(%1) : (f32) -> ()
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
  }
  return
}

// CHECK-LABEL: func @high_dim_permute()
func.func @high_dim_permute() {
  // CHECK-NOT: memref<64x128x256xf32,
  %A = memref.alloc() : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK: %[[I:arg[0-9a-zA-Z_]+]]
  affine.for %i = 0 to 64 {
    // CHECK: %[[J:arg[0-9a-zA-Z_]+]]
    affine.for %j = 0 to 128 {
      // CHECK: %[[K:arg[0-9a-zA-Z_]+]]
      affine.for %k = 0 to 256 {
        %1 = affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
        // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
        "prevent.dce"(%1) : (f32) -> ()
      }
    }
  }
  return
}

// CHECK-LABEL: func @invalid_map
func.func @invalid_map() {
  %A = memref.alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (d0, -d1 - 10)>>
  // CHECK: %{{.*}} = memref.alloc() : memref<64x128xf32,
  return
}
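// The map above produces negative indices (e.g. (0, 0) -> (0, -10)), so it
// cannot be rewritten to an identity layout and the alloc keeps its type.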

// A tiled layout.
// CHECK-LABEL: func @data_tiling
func.func @data_tiling(%idx : index) {
  // CHECK: memref.alloc() : memref<8x32x8x16xf32>
  %A = memref.alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
  %1 = affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  "prevent.dce"(%1) : (f32) -> ()
  return
}
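// Worked example for the tiled layout: index (10, 20) maps to
// (10 floordiv 8, 20 floordiv 16, 10 mod 8, 20 mod 16) = (1, 1, 2, 4)
// in the normalized memref<8x32x8x16xf32>.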

// Strides 2 and 4 along respective dimensions.
// CHECK-LABEL: func @strided
func.func @strided() {
  %A = memref.alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
  // CHECK: affine.for %[[IV0:.*]] =
  affine.for %i = 0 to 64 {
    // CHECK: affine.for %[[IV1:.*]] =
    affine.for %j = 0 to 128 {
      // CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32>
      %1 = affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}
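// The normalized shape comes from the largest physical index reached:
// 2 * 63 = 126 and 4 * 127 = 508, hence memref<127x509xf32>.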

// Strided, but the strides are in the linearized space.
// CHECK-LABEL: func @strided_cumulative
func.func @strided_cumulative() {
  %A = memref.alloc() : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
  // CHECK: affine.for %[[IV0:.*]] =
  affine.for %i = 0 to 2 {
    // CHECK: affine.for %[[IV1:.*]] =
    affine.for %j = 0 to 5 {
      // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
      %1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}
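// The largest linearized index is 3 * 1 + 17 * 4 = 71, hence the 72-element
// buffer.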

// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith
// when the index remap has symbols.
// CHECK-LABEL: func @symbolic_operands
func.func @symbolic_operands(%s : index) {
  // CHECK: memref.alloc() : memref<100xf32>
  %A = memref.alloc()[%s] : memref<10x10xf32, affine_map<(d0, d1)[s0] -> (10*d0 + d1)>>
  affine.for %i = 0 to 10 {
    affine.for %j = 0 to 10 {
      // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
      %1 = affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0, d1)[s0] -> (10*d0 + d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}
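// The map 10 * d0 + d1 linearizes the 10x10 memref into 100 contiguous
// elements; the unused symbol s0 is dropped along with the alloc operand.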

// Semi-affine layout maps; normalization is not implemented for these yet.
// CHECK-LABEL: func @semi_affine_layout_map
func.func @semi_affine_layout_map(%s0: index, %s1: index) {
  %A = memref.alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
  affine.for %i = 0 to 256 {
    affine.for %j = 0 to 1024 {
      // CHECK: memref<256x1024xf32, #map{{[0-9a-zA-Z_]+}}>
      affine.load %A[%i, %j] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
    }
  }
  return
}

// CHECK-LABEL: func @alignment
func.func @alignment() {
  %A = memref.alloc() {alignment = 32 : i64} : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK-NEXT: memref.alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
  return
}

#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
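// #tile maps a 1-d index i to the pair (i floordiv 4, i mod 4), so
// memref<16xf64, #tile> normalizes to memref<4x4xf64> and
// memref<8xf64, #tile> to memref<2x4xf64>.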

// The following test cases check inter-procedural memref normalization.

// Test case 1: Check normalization for multiple memrefs in a function argument list.
// CHECK-LABEL: func @multiple_argument_type
// CHECK-SAME:  (%[[A:arg[0-9a-zA-Z_]+]]: memref<4x4xf64>, %[[B:arg[0-9a-zA-Z_]+]]: f64, %[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>, %[[D:arg[0-9a-zA-Z_]+]]: memref<24xf64>) -> f64
func.func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = arith.mulf %a, %a : f64
  affine.store %p, %A[10] : memref<16xf64, #tile>
  call @single_argument_type(%C) : (memref<8xf64, #tile>) -> ()
  return %B : f64
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK: %[[p:[0-9a-zA-Z_]+]] = arith.mulf %[[a]], %[[a]] : f64
// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
// CHECK: return %[[B]] : f64
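// For example, the store at index 10 above maps to
// (10 floordiv 4, 10 mod 4) = (2, 2) in the normalized memref<4x4xf64>.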

// Test case 2: Check normalization for single memref argument in a function.
// CHECK-LABEL: func @single_argument_type
// CHECK-SAME: (%[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @single_argument_type(%C : memref<8xf64, #tile>) {
  %a = memref.alloc() : memref<8xf64, #tile>
  %b = memref.alloc() : memref<16xf64, #tile>
  %d = arith.constant 23.0 : f64
  %e = memref.alloc() : memref<24xf64>
  call @single_argument_type(%a) : (memref<8xf64, #tile>) -> ()
  call @single_argument_type(%C) : (memref<8xf64, #tile>) -> ()
  call @multiple_argument_type(%b, %d, %a, %e) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
  return
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = memref.alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9a-zA-Z_]+]] = memref.alloc() : memref<4x4xf64>
// CHECK: %cst = arith.constant 2.300000e+01 : f64
// CHECK: %[[e:[0-9a-zA-Z_]+]] = memref.alloc() : memref<24xf64>
// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64

// Test case 3: Check a function returning a type other than memref.
// CHECK-LABEL: func @non_memref_ret
// CHECK-SAME: (%[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>) -> i1
func.func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
  %d = arith.constant 1 : i1
  return %d : i1
}

// Test cases from here onwards deal with normalization of memrefs in
// function signatures and at call sites.

// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
// CHECK-LABEL: func @ret_multiple_argument_type
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<4x4xf64>, %[[B:arg[0-9a-zA-Z_]+]]: f64, %[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
func.func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = arith.mulf %a, %a : f64
  %cond = arith.constant 1 : i1
  cf.cond_br %cond, ^bb1, ^bb2
  ^bb1:
    %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
    return %res2, %p : memref<8xf64, #tile>, f64
  ^bb2:
    return %C, %p : memref<8xf64, #tile>, f64
}

// CHECK:   %[[a:[0-9a-zA-Z_]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK:   %[[p:[0-9a-zA-Z_]+]] = arith.mulf %[[a]], %[[a]] : f64
// CHECK:   %true = arith.constant true
// CHECK:   cf.cond_br %true, ^bb1, ^bb2
// CHECK: ^bb1:  // pred: ^bb0
// CHECK:   %[[res:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK:   return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
// CHECK: ^bb2:  // pred: ^bb0
// CHECK:   return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64

// CHECK-LABEL: func @ret_single_argument_type
// CHECK-SAME: (%[[C:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
func.func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>) {
  %a = memref.alloc() : memref<8xf64, #tile>
  %b = memref.alloc() : memref<16xf64, #tile>
  %d = arith.constant 23.0 : f64
  call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
  %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  return %b, %a : memref<16xf64, #tile>, memref<8xf64, #tile>
}

// CHECK: %[[a:[0-9a-zA-Z_]+]] = memref.alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9a-zA-Z_]+]] = memref.alloc() : memref<4x4xf64>
// CHECK: %cst = arith.constant 2.300000e+01 : f64
// CHECK: %[[resA:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resB:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resC:[0-9a-zA-Z_]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
// CHECK: %[[resD:[0-9a-zA-Z_]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>

// Test case set #5: Check normalization in a chain of interconnected functions.
// CHECK-LABEL: func @func_A
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @func_A(%A: memref<8xf64, #tile>) {
  call @func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()

// CHECK-LABEL: func @func_B
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @func_B(%A: memref<8xf64, #tile>) {
  call @func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()

// CHECK-LABEL: func @func_C
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<2x4xf64>)
func.func @func_C(%A: memref<8xf64, #tile>) {
  return
}

// Test case set #6: Check that no normalization takes place in a chain
// A -> B -> C when B passes the memref to an unknown op.
// CHECK-LABEL: func @some_func_A
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<8xf64, #map{{[0-9a-zA-Z_]+}}>)
func.func @some_func_A(%A: memref<8xf64, #tile>) {
  call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9a-zA-Z_]+}}>) -> ()

// CHECK-LABEL: func @some_func_B
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<8xf64, #map{{[0-9a-zA-Z_]+}}>)
func.func @some_func_B(%A: memref<8xf64, #tile>) {
  "test.test"(%A) : (memref<8xf64, #tile>) -> ()
  call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9a-zA-Z_]+}}>) -> ()

// CHECK-LABEL: func @some_func_C
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<8xf64, #map{{[0-9a-zA-Z_]+}}>)
func.func @some_func_C(%A: memref<8xf64, #tile>) {
  return
}

// Test case set #7: Check normalization in case of external functions.
// CHECK-LABEL: func private @external_func_A
// CHECK-SAME: (memref<4x4xf64>)
func.func private @external_func_A(memref<16xf64, #tile>) -> ()

// CHECK-LABEL: func private @external_func_B
// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
func.func private @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)

// CHECK-LABEL: func @simply_call_external()
func.func @simply_call_external() {
  %a = memref.alloc() : memref<16xf64, #tile>
  call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
  return
}
// CHECK: %[[a:[0-9a-zA-Z_]+]] = memref.alloc() : memref<4x4xf64>
// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()

// CHECK-LABEL: func @use_value_of_external
// CHECK-SAME: (%[[A:arg[0-9a-zA-Z_]+]]: memref<4x4xf64>, %[[B:arg[0-9a-zA-Z_]+]]: f64) -> memref<2x4xf64>
func.func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
  %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
  return %res : memref<8xf64, #tile>
}
// CHECK: %[[res:[0-9a-zA-Z_]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
// CHECK: return %{{.*}} : memref<2x4xf64>

// CHECK-LABEL: func @affine_parallel_norm
func.func @affine_parallel_norm() -> memref<8xf32, #tile> {
  %c = arith.constant 23.0 : f32
  %a = memref.alloc() : memref<8xf32, #tile>
  // CHECK: affine.parallel (%{{.*}}) = (0) to (8) reduce ("assign") -> (memref<2x4xf32>)
  %1 = affine.parallel (%i) = (0) to (8) reduce ("assign") -> memref<8xf32, #tile> {
    affine.store %c, %a[%i] : memref<8xf32, #tile>
    // CHECK: affine.yield %{{.*}} : memref<2x4xf32>
    affine.yield %a : memref<8xf32, #tile>
  }
  return %1 : memref<8xf32, #tile>
}

#map = affine_map<(d0, d1)[s0] -> (d0 * 3 + s0 + d1)>
// CHECK-LABEL: func.func @map_symbol
func.func @map_symbol() -> memref<2x3xf32, #map> {
  %c1 = arith.constant 1 : index
  // The constant isn't propagated here, and without it the utility can't
  // compute a constant upper bound for the memref dimension.
  // CHECK: memref.alloc()[%{{.*}}]
  %0 = memref.alloc()[%c1] : memref<2x3xf32, #map>
  return %0 : memref<2x3xf32, #map>
}

#neg = affine_map<(d0, d1) -> (d0, d1 - 100)>
// CHECK-LABEL: func.func @neg_map
func.func @neg_map() -> memref<2x3xf32, #neg> {
  // The map can produce negative indices, so it isn't valid for normalization.
  // CHECK: memref.alloc() : memref<2x3xf32, #{{.*}}>
  %0 = memref.alloc() : memref<2x3xf32, #neg>
  return %0 : memref<2x3xf32, #neg>
}

// CHECK-LABEL: func @memref_with_strided_offset
func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
  %c0 = arith.constant 0 : index
  %0 = bufferization.to_memref %arg0 : tensor<128x512xf32> to memref<128x512xf32, strided<[?, ?], offset: ?>>
  %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
  %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
  %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>> to tensor<16x512xf32>
  return %1 : tensor<16x512xf32>
}
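// The strides and offset here are dynamic, so there is no constant layout map
// to normalize; the subview/cast chain is expected to stay unchanged.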

#map0 = affine_map<(i, k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
#map1 = affine_map<(k, j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
#map2 = affine_map<(i, j) -> (4 * i + j)>
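// Each map linearizes its 2-d memref into a 1-d buffer whose size is the
// largest mapped index plus one: #map0 and #map1 cover 32 elements for the
// 4x8 and 8x4 memrefs, and #map2 covers 16 elements for the 4x4 one. The
// normalized access functions are the REDUCE_MAPs captured at the top of
// this file.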
// CHECK-LABEL: func @memref_load_with_reduction_map
func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32, #map2>) -> () {
  %0 = memref.alloc() : memref<4x8xf32, #map0>
  %1 = memref.alloc() : memref<8x4xf32, #map1>
  %2 = memref.alloc() : memref<4x4xf32, #map2>
  // CHECK-NOT:  memref<4x8xf32>
  // CHECK-NOT:  memref<8x4xf32>
  // CHECK-NOT:  memref<4x4xf32>
  %cst = arith.constant 3.0 : f32
  %cst0 = arith.constant 0 : index
  affine.for %i = 0 to 4 {
    affine.for %j = 0 to 8 {
      affine.for %k = 0 to 8 {
        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
        %a = memref.load %0[%i, %k] : memref<4x8xf32, #map0>
        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
        %b = memref.load %1[%k, %j] : memref<8x4xf32, #map1>
        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
        %c = memref.load %2[%i, %j] : memref<4x4xf32, #map2>
        %3 = arith.mulf %a, %b : f32
        %4 = arith.addf %3, %c : f32
        affine.store %4, %arg0[%i, %j] : memref<4x4xf32, #map2>
      }
    }
  }
  return
}