// xref: /llvm-project/mlir/test/Dialect/MemRef/extract-address-computations.mlir (revision 2798b72ae7e5caad793169b77cbac47fe2362d0f)
// RUN: mlir-opt -transform-interpreter %s --split-input-file --verify-diagnostics | FileCheck %s

// Simple test: check that we extract the address computation of a load into
// a dedicated subview.
// The resulting load will be loading from the subview and have only indices
// set to zero.

// CHECK-LABEL: @test_load(
// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: return %[[LOADED_VAL]] : f32

// expected-remark @below {{transformed}}
func.func @test_load(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %loaded_val = memref.load %base[%offset, %c0, %c8] : memref<2x16x16xf32>
  return %loaded_val : f32
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    // Verify that the returned handle is usable.
    transform.debug.emit_remark_at %0, "transformed" : !transform.any_op
    transform.yield
  }
}

// -----

// Same as previous @test_load but with the nontemporal flag.

// CHECK-LABEL: @test_load_nontemporal(
// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: return %[[LOADED_VAL]] : f32
func.func @test_load_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %loaded_val = memref.load %base[%offset, %c0, %c8] {nontemporal = true } : memref<2x16x16xf32>
  return %loaded_val : f32
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Simple test: check that we extract the address computation of a store into
// a dedicated subview.
// The resulting store will use the address from the subview and have only
// indices set to zero.

// CHECK-LABEL: @test_store(
// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: return
func.func @test_store(%base : memref<2x16x16xf32>, %offset : index) -> () {
  %cf0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  memref.store %cf0, %base[%offset, %c0, %c8] : memref<2x16x16xf32>
  return
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Same as @test_store but check that the nontemporal flag is preserved.

// CHECK-LABEL: @test_store_nontemporal(
// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
// CHECK: return
func.func @test_store_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> () {
  %cf0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  memref.store %cf0, %base[%offset, %c0, %c8] { nontemporal = true } : memref<2x16x16xf32>
  return
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----
// For this test, we made the source memref fully dynamic.
// The gist of the check remains the same as the simple test:
// The address computation is extracted into its own subview.
// CHECK-LABEL: @testWithLoop(
// CHECK-SAME: %[[BASE:[^:]*]]: memref
// CHECK:  %[[SUM_ALL:.*]] = arith.constant 0.0{{0*e\+00}} : f32
// CHECK:  %[[C0:.*]] = arith.constant 0 : index
// CHECK:  %[[C1:.*]] = arith.constant 1 : index
// CHECK:  %[[C2:.*]] = arith.constant 2 : index
// CHECK:  %[[UPPER_BOUND0:.*]] = memref.dim %[[BASE]], %[[C0]] : memref<?x?x?xf32,
// CHECK:  %[[UPPER_BOUND1:.*]] = memref.dim %[[BASE]], %[[C1]] : memref<?x?x?xf32,
// CHECK:  %[[UPPER_BOUND2:.*]] = memref.dim %[[BASE]], %[[C2]] : memref<?x?x?xf32,
// CHECK:  %[[SUM_RES2:.*]] = scf.for %[[IV2:.*]] = %[[C0]] to %[[UPPER_BOUND2]] step %[[C1]] iter_args(%[[SUM_ITER2:.*]] = %[[SUM_ALL]]) -> (f32) {
// CHECK:    %[[SUM_RES1:.*]] = scf.for %[[IV1:.*]] = %[[C0]] to %[[UPPER_BOUND1]] step %[[C1]] iter_args(%[[SUM_ITER1:.*]] = %[[SUM_ITER2]]) -> (f32) {
// CHECK:      %[[SUM_RES0:.*]] = scf.for %[[IV0:.*]] = %[[C0]] to %[[UPPER_BOUND0]] step %[[C1]] iter_args(%[[SUM_ITER0:.*]] = %[[SUM_ITER1]]) -> (f32) {
// CHECK:        %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[IV0]], %[[IV1]], %[[IV2]]] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
// CHECK:        %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
// CHECK:        %[[RES:.*]] = arith.addf %[[LOADED_VAL]], %[[SUM_ITER2]] : f32
// CHECK:        scf.yield %[[RES]] : f32
// CHECK:      }
// CHECK:      scf.yield %[[SUM_RES0]] : f32
// CHECK:    }
// CHECK:    scf.yield %[[SUM_RES1]] : f32
// CHECK:  }
// CHECK:  return %[[SUM_RES2]] : f32
func.func @testWithLoop(%base : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>) -> f32 {
  %sum_all = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %upper_bound0 = memref.dim %base, %c0 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
  %upper_bound1 = memref.dim %base, %c1 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
  %upper_bound2 = memref.dim %base, %c2 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
  %sum_res2 = scf.for %iv2 = %c0 to %upper_bound2 step %c1 iter_args(%sum_iter2 = %sum_all) -> (f32) {
    %sum_res1 = scf.for %iv1 = %c0 to %upper_bound1 step %c1 iter_args(%sum_iter1 = %sum_iter2) -> (f32) {
      %sum_res0 = scf.for %iv0 = %c0 to %upper_bound0 step %c1 iter_args(%sum_iter0 = %sum_iter1) -> (f32) {
        // NOTE(review): the addf reads %sum_iter2 (the outermost iter_arg)
        // rather than %sum_iter0; the CHECK lines above match this as-is —
        // confirm this is the intended reduction shape for the test.
        %loaded_val = memref.load %base[%iv0, %iv1, %iv2] : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
        %res = arith.addf %loaded_val, %sum_iter2 : f32
        scf.yield %res : f32
      }
      scf.yield %sum_res0 : f32
    }
    scf.yield %sum_res1 : f32
  }
  return %sum_res2 : f32
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Simple test: check that we extract the address computation of a ldmatrix into
// a dedicated subview.
// The resulting ldmatrix will loaded from with subview and have only indices set
// to zero.
// Also the sizes of the view are adjusted to `original size - offset`.

// CHECK-DAG: #[[$FOUR_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 4)>
// CHECK-DAG: #[[$THIRTY_TWO_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 32)>
// CHECK-LABEL: @test_ldmatrix(
// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$FOUR_MINUS_OFF_MAP]]()[%[[DYN_OFFSET0]]]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET1]]]
// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET2]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3> -> vector<4x2xf16>
// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
func.func @test_ldmatrix(%base : memref<4x32x32xf16, 3>,
    %offset0 : index, %offset1: index, %offset2: index)
    -> vector<4x2xf16> {
  %loaded_val = nvgpu.ldmatrix
    %base[%offset0, %offset1, %offset2]
    {numTiles = 4 : i32, transpose = false}
      : memref<4x32x32xf16, 3> -> vector<4x2xf16>
  return %loaded_val : vector<4x2xf16>
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Same as test_ldmatrix but with fully dynamic memref.

// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
// CHECK-LABEL: @test_ldmatrix(
// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16, 3> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>, 3>
// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>, 3> -> vector<4x2xf16>
// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
func.func @test_ldmatrix(%base : memref<?x?x?xf16, 3>,
    %offset0 : index, %offset1: index, %offset2: index)
    -> vector<4x2xf16> {
  %loaded_val = nvgpu.ldmatrix
    %base[%offset0, %offset1, %offset2]
    {numTiles = 4 : i32, transpose = false}
      : memref<?x?x?xf16, 3> -> vector<4x2xf16>
  return %loaded_val : vector<4x2xf16>
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Simple test for vector.transfer_read with fully dynamic memref.
// We also set a permutation map to make sure it is properly preserved.

// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-LABEL: @test_transfer_read_op(
// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>, vector<4x2xf16>
// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
func.func @test_transfer_read_op(%base : memref<?x?x?xf16>,
    %offset0 : index, %offset1: index, %offset2: index)
    -> vector<4x2xf16> {
  %cf0 = arith.constant 0.0 : f16
  %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : memref<?x?x?xf16>, vector<4x2xf16>
  return %loaded_val : vector<4x2xf16>
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Same as test_transfer_read_op but with tensors.
// Right now this rewrite is not supported but we still shouldn't choke on it.

// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-LABEL: @test_transfer_read_op_with_tensor(
// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16
// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : tensor<?x?x?xf16>, vector<4x2xf16>
// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
func.func @test_transfer_read_op_with_tensor(%base : tensor<?x?x?xf16>,
    %offset0 : index, %offset1: index, %offset2: index)
    -> vector<4x2xf16> {
  %cf0 = arith.constant 0.0 : f16
  %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : tensor<?x?x?xf16>, vector<4x2xf16>
  return %loaded_val : vector<4x2xf16>
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Simple test for vector.transfer_write with fully dynamic memref.
// We also set a permutation map to make sure it is properly preserved.

// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-LABEL: @test_transfer_write_op(
// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
// CHECK: return
func.func @test_transfer_write_op(%base : memref<?x?x?xf16>,
    %offset0 : index, %offset1: index, %offset2: index) {
  %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
  vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref<?x?x?xf16>
  return
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Check that the strides of the original memref are kept.
// Moreover even with non-1 strides the subview should still issue [1,...]
// strides, since this is a multiplication factor.

// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-LABEL: @test_transfer_write_op_with_strides(
// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^>]*}}>>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>> to memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>
// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>
// CHECK: return
func.func @test_transfer_write_op_with_strides(%base : memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>,
    %offset0 : index, %offset1: index, %offset2: index) {
  %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
  vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>
  return
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Same as test_transfer_write_op but with tensors.
// Right now this rewrite is not supported but we still shouldn't choke on it.

// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-LABEL: @test_transfer_write_op_with_tensor(
// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>,
// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[VCF0]], %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, tensor<?x?x?xf16>
// CHECK: return %[[RES]] : tensor<?x?x?xf16>
func.func @test_transfer_write_op_with_tensor(%base : tensor<?x?x?xf16>,
    %offset0 : index, %offset1: index, %offset2: index) -> tensor<?x?x?xf16> {
  %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
  %res = vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, tensor<?x?x?xf16>
  return %res : tensor<?x?x?xf16>
}

// Transform script: apply the extract-address-computations patterns to every
// func.func in the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.memref.extract_address_computations
    } : !transform.any_op
    transform.yield
  }
}
