xref: /llvm-project/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (revision c42512436b23ab50e7637f239abe8371407104a1)
1// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
2
3//===----------------------------------------------------------------------===//
4// vector.transfer_read
5//===----------------------------------------------------------------------===//
6
// CHECK-LABEL: @transfer_read_2d_i8
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
/// In-bounds, unmasked 2-D read of an i8 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i8
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
  "prevent.dce"(%tile) : (vector<[16]x[16]xi8>) -> ()
  return
}
16
17// -----
18
// CHECK-LABEL: @transfer_read_2d_i16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
/// In-bounds, unmasked 2-D read of an i16 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i16
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
  "prevent.dce"(%tile) : (vector<[8]x[8]xi16>) -> ()
  return
}
28
29// -----
30
// CHECK-LABEL: @transfer_read_2d_i32
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
/// In-bounds, unmasked 2-D read of an i32 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i32
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
  "prevent.dce"(%tile) : (vector<[4]x[4]xi32>) -> ()
  return
}
40
41// -----
42
// CHECK-LABEL: @transfer_read_2d_i64
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
/// In-bounds, unmasked 2-D read of an i64 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i64
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
  "prevent.dce"(%tile) : (vector<[2]x[2]xi64>) -> ()
  return
}
52
53// -----
54
// CHECK-LABEL: @transfer_read_2d_i128
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
/// In-bounds, unmasked 2-D read of an i128 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i128
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
  "prevent.dce"(%tile) : (vector<[1]x[1]xi128>) -> ()
  return
}
64
65// -----
66
// CHECK-LABEL: @transfer_read_2d_f16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
/// In-bounds, unmasked 2-D read of an f16 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0.0 : f16
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
  "prevent.dce"(%tile) : (vector<[8]x[8]xf16>) -> ()
  return
}
76
77// -----
78
// CHECK-LABEL: @transfer_read_2d_bf16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
/// In-bounds, unmasked 2-D read of a bf16 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0.0 : bf16
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
  "prevent.dce"(%tile) : (vector<[8]x[8]xbf16>) -> ()
  return
}
88
89// -----
90
// CHECK-LABEL: @transfer_read_2d_f32
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
/// In-bounds, unmasked 2-D read of an f32 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0.0 : f32
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  "prevent.dce"(%tile) : (vector<[4]x[4]xf32>) -> ()
  return
}
100
101// -----
102
// CHECK-LABEL: @transfer_read_2d_f64
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
/// In-bounds, unmasked 2-D read of an f64 tile: lowers to a single arm_sme.tile_load.
func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0.0 : f64
  %tile = vector.transfer_read %src[%zero, %zero], %padding {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
  return
}
112
113// -----
114
// CHECK-LABEL: @transfer_read_2d_with_mask_i16
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
/// Masked 2-D read: the resulting arm_sme.tile_load carries two extra operands
/// after the indices (presumably the padding value and the mask).
func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i16
  %tile = vector.transfer_read %src[%zero, %zero], %padding, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
  "prevent.dce"(%tile) : (vector<[8]x[8]xi16>) -> ()
  return
}
124
125// -----
126
/// In-flight transpose: a transfer_read with a (d1, d0) permutation map lowers
/// to a vertical tile load.
128
// CHECK-LABEL: @transfer_read_2d_transpose_i8
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
/// Transposing read (permutation (d1, d0)) of an i8 tile: lowers to a
/// vertical-layout arm_sme.tile_load.
func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0 : i8
  %tile = vector.transfer_read %src[%zero, %zero], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
  "prevent.dce"(%tile) : (vector<[16]x[16]xi8>) -> ()
  return
}
138
139// -----
140
// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
/// Masked, transposing read of an f32 tile: still lowers to a vertical-layout
/// arm_sme.tile_load.
func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
  %zero = arith.constant 0 : index
  %padding = arith.constant 0.0 : f32
  %tile = vector.transfer_read %src[%zero, %zero], %padding, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  "prevent.dce"(%tile) : (vector<[4]x[4]xf32>) -> ()
  return
}
150
151// -----
152
// CHECK-LABEL: @fold_transpose_into_load
// CHECK-NOT: arm_sme.tile_store
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-NOT: arm_sme.tile_store
/// A vector.transpose whose operand comes from a transfer_read (with no other
/// users) is folded into that read: the output contains a single vertical
/// tile_load and no tile_store, i.e. no transpose through memory.
func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
  "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
  // Explicit terminator, consistent with the other tests in this file (the
  // original relied on the unregistered op being accepted as a terminator).
  return
}
164
165// -----
166
/// A transpose cannot be folded into the load when the loaded value has more
/// than one use; the transpose is instead performed via memory (a tile store
/// followed by a vertical tile load).
169
// CHECK-LABEL: @fold_transpose_into_load_multi_use
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: %[[TILE_TRANSPOSED_VIA_MEM:.*]] = arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: "prevent.dce"(%[[TILE_TRANSPOSED_VIA_MEM]]) : (vector<[4]x[4]xf32>) -> ()
/// The loaded tile %0 has a second user ("test.some_use"), so the transpose is
/// not folded into the read: the tile is loaded horizontally, stored, and
/// reloaded with a vertical layout.
func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
  "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
  %1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
  "prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
  // Explicit terminator, consistent with the other tests in this file.
  return
}
183
184// -----
185
186//===----------------------------------------------------------------------===//
187// vector.transfer_write
188//===----------------------------------------------------------------------===//
189
// CHECK-LABEL: func.func @transfer_write_2d_i8(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi8>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
/// In-bounds, unmasked write of an i8 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
  return
}
200
201// -----
202
// CHECK-LABEL: func.func @transfer_write_2d_i16(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xi16>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi16>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
/// In-bounds, unmasked write of an i16 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[8]x[8]xi16>, memref<?x?xi16>
  return
}
213
214// -----
215
// CHECK-LABEL: func.func @transfer_write_2d_i32(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
/// In-bounds, unmasked write of an i32 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[4]x[4]xi32>, memref<?x?xi32>
  return
}
226
227// -----
228
// CHECK-LABEL: func.func @transfer_write_2d_i64(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi64>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
/// In-bounds, unmasked write of an i64 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
  return
}
239
240// -----
241
// CHECK-LABEL: func.func @transfer_write_2d_f16(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xf16>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf16>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
/// In-bounds, unmasked write of an f16 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
  return
}
252
253// -----
254
// CHECK-LABEL: func.func @transfer_write_2d_bf16(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xbf16>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
/// In-bounds, unmasked write of a bf16 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
  return
}
265
266// -----
267
// CHECK-LABEL: func.func @transfer_write_2d_f32(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf32>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
/// In-bounds, unmasked write of an f32 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}
278
279// -----
280
// CHECK-LABEL: func.func @transfer_write_2d_f64(
// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf64>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
/// In-bounds, unmasked write of an f64 tile at [0, 0]: lowers to a single arm_sme.tile_store.
func.func @transfer_write_2d_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
  return
}
291
292// -----
293
// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xf64>,
// CHECK-SAME:                                             %[[MASK:.*]]: vector<[2]x[2]xi1>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] : memref<?x?xf64>, vector<[2]x[2]xf64>
/// Masked write of an f64 tile: the mask is forwarded as the trailing operand
/// of arm_sme.tile_store.
func.func @transfer_write_2d_with_mask_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
  return
}
305
306// -----
307
308/// in-flight transpose via vertical store.
309
// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi64>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
/// Transposing write (permutation (d1, d0)) of an i64 tile: lowers to a
/// vertical-layout arm_sme.tile_store.
func.func @transfer_write_2d_transpose_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
  return
}
320
321// -----
322
323/// in-flight transpose via vertical store with mask.
324
// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
// CHECK-SAME:                                                        %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME:                                                        %[[DEST:.*]]: memref<?x?xbf16>,
// CHECK-SAME:                                                        %[[MASK:.*]]: vector<[8]x[8]xi1>) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
/// Masked, transposing write of a bf16 tile: lowers to a vertical-layout
/// arm_sme.tile_store that keeps the mask operand.
func.func @transfer_write_2d_transpose_with_mask_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
  %zero = arith.constant 0 : index
  vector.transfer_write %tile, %dest[%zero, %zero], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
  return
}
336
337// -----
338
// CHECK-LABEL: func.func @transfer_write_slice(
// CHECK-SAME:                                  %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
// CHECK-SAME:                                  %[[DEST:.*]]: memref<?x?xf32>,
// CHECK-SAME:                                  %[[INDEX:.*]]: index) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         %[[MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK:         arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
/// Writing one extracted 1-D slice of a tile becomes arm_sme.store_tile_slice
/// directly on the tile; with no mask given, an all-true mask is materialized.
func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
  %zero = arith.constant 0 : index
  %extracted = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
  vector.transfer_write %extracted, %dest[%slice_index, %zero] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
  return
}
352
353// -----
354
// CHECK-LABEL: func.func @transfer_write_slice_with_mask(
// CHECK-SAME:                                            %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
// CHECK-SAME:                                            %[[DEST:.*]]: memref<?x?xf32>,
// CHECK-SAME:                                            %[[MASK:.*]]: vector<[4]xi1>,
// CHECK-SAME:                                            %[[INDEX:.*]]: index) {
// CHECK:         %[[C0:.*]] = arith.constant 0 : index
// CHECK:         arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
/// Same as the unmasked slice write, but the caller-provided 1-D mask is passed
/// straight through to arm_sme.store_tile_slice.
func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask: vector<[4]xi1>, %slice_index: index) {
  %zero = arith.constant 0 : index
  %extracted = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
  vector.transfer_write %extracted, %dest[%slice_index, %zero], %mask { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
  return
}
368
369// -----
370
// CHECK-LABEL: func.func @transfer_write_vertical_slice
// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
/// Writing a slice taken with a vertical-layout arm_sme.extract_tile_slice is
/// lowered to a vertical-layout arm_sme.store_tile_slice.
func.func @transfer_write_vertical_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
  %zero = arith.constant 0 : index
  %vslice = arm_sme.extract_tile_slice %vector[%slice_index] layout<vertical>
              : vector<[4]xf32> from vector<[4]x[4]xf32>
  vector.transfer_write %vslice, %dest[%slice_index, %zero] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
  return
}
380
381//===----------------------------------------------------------------------===//
382// vector.broadcast
383//===----------------------------------------------------------------------===//
384
385// -----
386
// CHECK-LABEL:   func.func @broadcast_vec2d_from_i32(
// CHECK-SAME:                                        %[[SRC:.*]]: i32) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
// CHECK:   %[[NEW_TILE:.*]] = arm_sme.insert_tile_slice %[[SRC_1D]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK:   scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[4]x[4]xi32>) -> ()
/// Scalar-to-tile broadcast: the scalar is first broadcast to a 1-D vector,
/// which is then inserted into every tile slice by an scf.for loop running
/// over vscale * 4 slices.
func.func @broadcast_vec2d_from_i32(%scalar: i32) {
  %splat = vector.broadcast %scalar : i32 to vector<[4]x[4]xi32>
  "prevent.dce"(%splat) : (vector<[4]x[4]xi32>) -> ()
  return
}
405
406// -----
407
// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec0d(
// CHECK-SAME:                                          %[[SRC:.*]]: vector<f32>) {
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : vector<f32> to vector<[4]xf32>
// CHECK: scf.for
// CHECK:   arm_sme.insert_tile_slice %[[SRC_1D]], {{.*}}
/// 0-D vector to tile broadcast: handled like the scalar case, via a 1-D
/// broadcast followed by a loop of arm_sme.insert_tile_slice.
func.func @broadcast_vec2d_from_vec0d(%scalar_vec: vector<f32>) {
  %splat = vector.broadcast %scalar_vec : vector<f32> to vector<[4]x[4]xf32>
  "prevent.dce"(%splat) : (vector<[4]x[4]xf32>) -> ()
  return
}
418
419// -----
420
// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec1d(
// CHECK-SAME:                                          %[[SRC:.*]]: vector<[8]xi16>) {
// CHECK-NOT: vector.broadcast
// CHECK: scf.for
// CHECK:   arm_sme.insert_tile_slice %[[SRC]], {{.*}}
/// 1-D vector to tile broadcast: the source already has the slice type, so no
/// extra vector.broadcast is emitted before the insert_tile_slice loop.
func.func @broadcast_vec2d_from_vec1d(%row: vector<[8]xi16>) {
  %splat = vector.broadcast %row : vector<[8]xi16> to vector<[8]x[8]xi16>
  "prevent.dce"(%splat) : (vector<[8]x[8]xi16>) -> ()
  return
}
431
432//===----------------------------------------------------------------------===//
433// vector.splat
434//===----------------------------------------------------------------------===//
435
436// -----
437
// CHECK-LABEL:   func.func @splat_vec2d_from_i32(
// CHECK-SAME:      %[[SRC:.*]]: i32) {
// CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK:   arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK:   %[[VSCALE:.*]] = vector.vscale
// CHECK:   %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
// CHECK:   scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
// CHECK:     arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
/// vector.splat of an i32 into a tile: lowered like a scalar broadcast, with a
/// 1-D broadcast inserted into each tile slice in a vscale-bounded loop.
func.func @splat_vec2d_from_i32(%scalar: i32) {
  %splat = vector.splat %scalar : vector<[4]x[4]xi32>
  "prevent.dce"(%splat) : (vector<[4]x[4]xi32>) -> ()
  return
}
451
452// -----
453
// CHECK-LABEL:   func.func @splat_vec2d_from_f16(
// CHECK-SAME:      %[[SRC:.*]]: f16) {
// CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
// CHECK:   scf.for
// CHECK:     arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
/// vector.splat of an f16 into a tile: a 1-D broadcast inserted into each tile
/// slice inside an scf.for loop.
func.func @splat_vec2d_from_f16(%scalar: f16) {
  %splat = vector.splat %scalar : vector<[8]x[8]xf16>
  "prevent.dce"(%splat) : (vector<[8]x[8]xf16>) -> ()
  return
}
464
465//===----------------------------------------------------------------------===//
466// vector.transpose
467//===----------------------------------------------------------------------===//
468
469// -----
470
// CHECK-LABEL:   func.func @transpose_i8(
// CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[VSCALE:.*]] = vector.vscale
// CHECK:           %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK:           %[[TILE_BUFFER:.*]] = memref.alloca(%[[NUM_TILE_SLICES]], %[[NUM_TILE_SLICES]]) : memref<?x?xi8>
// CHECK:           arm_sme.tile_store %[[TILE]], %[[TILE_BUFFER]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
// CHECK:           arm_sme.tile_load %[[TILE_BUFFER]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
/// A standalone 2-D transpose is lowered through memory: a (vscale * 16)^2
/// buffer is allocated with memref.alloca, the tile is stored into it
/// horizontally, and then reloaded with a vertical layout.
/// (FileCheck names: NUM_TILE_SLICES is the slice count, TILE_BUFFER the
/// allocated buffer.)
func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
  return
}
485
486// -----
487
// CHECK-LABEL: @transpose_i16
// CHECK: arith.constant 8
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
/// i16 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_i16(%tile: vector<[8]x[8]xi16>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
  "prevent.dce"(%transposed) : (vector<[8]x[8]xi16>) -> ()
  return
}
497
498// -----
499
// CHECK-LABEL: @transpose_i32
// CHECK: arith.constant 4
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
/// i32 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_i32(%tile: vector<[4]x[4]xi32>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
  "prevent.dce"(%transposed) : (vector<[4]x[4]xi32>) -> ()
  return
}
509
510// -----
511
// CHECK-LABEL: @transpose_i64
// CHECK: arith.constant 2
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
/// i64 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_i64(%tile: vector<[2]x[2]xi64>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
  "prevent.dce"(%transposed) : (vector<[2]x[2]xi64>) -> ()
  return
}
521
522// -----
523
// CHECK-LABEL: @transpose_i128
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[TILE_BUFFER:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
/// i128 tile transpose via memory. With one element per vscale, the buffer is
/// sized directly by vscale (no multiply). TILE_BUFFER names the allocated
/// buffer; it is not a slice count.
func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
  %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
  return
}
534
535// -----
536
// CHECK-LABEL: @transpose_f16
// CHECK: arith.constant 8
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
/// f16 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_f16(%tile: vector<[8]x[8]xf16>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
  "prevent.dce"(%transposed) : (vector<[8]x[8]xf16>) -> ()
  return
}
546
547// -----
548
// CHECK-LABEL: @transpose_bf16
// CHECK: arith.constant 8
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
/// bf16 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_bf16(%tile: vector<[8]x[8]xbf16>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
  "prevent.dce"(%transposed) : (vector<[8]x[8]xbf16>) -> ()
  return
}
558
559// -----
560
// CHECK-LABEL: @transpose_f32
// CHECK: arith.constant 4
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
/// f32 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_f32(%tile: vector<[4]x[4]xf32>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
  "prevent.dce"(%transposed) : (vector<[4]x[4]xf32>) -> ()
  return
}
570
571// -----
572
// CHECK-LABEL: @transpose_f64
// CHECK: arith.constant 2
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
/// f64 tile transpose via memory: horizontal tile_store, then vertical tile_load.
func.func @transpose_f64(%tile: vector<[2]x[2]xf64>) {
  %transposed = vector.transpose %tile, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
  "prevent.dce"(%transposed) : (vector<[2]x[2]xf64>) -> ()
  return
}
582
583//===----------------------------------------------------------------------===//
584// vector.outerproduct
585//===----------------------------------------------------------------------===//
586
587// -----
588
// CHECK-LABEL: @vector_outerproduct_masked_f16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
/// A vector.mask around vector.outerproduct with a 2-D create_mask lowers to
/// arm_sme.outerproduct with two 1-D masks, one built from each mask dimension.
func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>
  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
  // Explicit terminator, consistent with the other tests in this file.
  return
}
599
600// -----
601
// CHECK-LABEL: @vector_outerproduct_masked_bf16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
/// bf16 variant: the 2-D create_mask is split into per-operand 1-D masks on
/// arm_sme.outerproduct.
func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>
  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
  // Explicit terminator, consistent with the other tests in this file.
  return
}
612
613// -----
614
// CHECK-LABEL: @vector_outerproduct_masked_f32
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
/// f32 variant: the 2-D create_mask is split into per-operand 1-D masks on
/// arm_sme.outerproduct.
func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
  // Explicit terminator, consistent with the other tests in this file.
  return
}
625
626// -----
627
// CHECK-LABEL: @vector_outerproduct_masked_f64
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
/// f64 variant: the 2-D create_mask is split into per-operand 1-D masks on
/// arm_sme.outerproduct.
func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
  // Explicit terminator, consistent with the other tests in this file.
  return
}
638
639// -----
640
// Unmasked f16 outer product. The function ends with an explicit `return`
// instead of relying on the unregistered "prevent.dce" op being accepted as a
// possible terminator.
// CHECK-LABEL: @vector_outerproduct_f16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>
func.func @vector_outerproduct_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xf16>, vector<[8]xf16>
  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
  return
}
648
649// -----
650
// Unmasked bf16 outer product. The function ends with an explicit `return`
// instead of relying on the unregistered "prevent.dce" op being accepted as a
// possible terminator.
// CHECK-LABEL: @vector_outerproduct_bf16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>
func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xbf16>, vector<[8]xbf16>
  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
  return
}
658
659// -----
660
// Unmasked f32 outer product. The function ends with an explicit `return`
// instead of relying on the unregistered "prevent.dce" op being accepted as a
// possible terminator.
// CHECK-LABEL: @vector_outerproduct_f32
// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
  return
}
668
669// -----
670
// Unmasked f64 outer product. The function ends with an explicit `return`
// instead of relying on the unregistered "prevent.dce" op being accepted as a
// possible terminator.
// CHECK-LABEL: @vector_outerproduct_f64
// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
  return
}
678
679//===----------------------------------------------------------------------===//
680// vector.print
681//===----------------------------------------------------------------------===//
682
683// -----
684
// Printing a whole 2-D scalable tile is lowered to an scf.for over its
// (vscale * 4) slices; each slice is read with arm_sme.extract_tile_slice and
// printed as a 1-D vector.
// CHECK-LABEL: func.func @vector_print_tile(
// CHECK-SAME: %[[IN_TILE:.*]]: vector<[4]x[4]xf32>) {
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[FOUR:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VS:.*]] = vector.vscale
// CHECK-DAG: %[[SLICE_COUNT:.*]] = arith.muli %[[VS]], %[[FOUR]] : index
// CHECK-NEXT: scf.for %[[SLICE_IDX:.*]] = %[[ZERO]] to %[[SLICE_COUNT]] step %[[ONE]] {
// CHECK-NEXT: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %[[IN_TILE]][%[[SLICE_IDX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.print %[[SLICE_VEC]] : vector<[4]xf32>
func.func @vector_print_tile(%t: vector<[4]x[4]xf32>) {
  vector.print %t : vector<[4]x[4]xf32>
  return
}
699// CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>
700
701//===----------------------------------------------------------------------===//
702// vector.load
703//===----------------------------------------------------------------------===//
704
705// -----
706
// The row offset of vector.load is forwarded to arm_sme.tile_load unchanged.
// CHECK-LABEL: @vector_load_i8_with_offset(
// CHECK-SAME: %[[BUF:.*]]: memref<?x?xi8>)
func.func @vector_load_i8_with_offset(%buf : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
  // CHECK: %[[COL:.*]] = arith.constant 0 : index
  // CHECK: %[[ROW:.*]] = arith.constant 123 : index
  // CHECK: arm_sme.tile_load %[[BUF]][%[[ROW]], %[[COL]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
  %col = arith.constant 0 : index
  %row = arith.constant 123 : index
  %loaded = vector.load %buf[%row, %col] : memref<?x?xi8>, vector<[16]x[16]xi8>
  return %loaded : vector<[16]x[16]xi8>
}
718
719// -----
720
// A rank-1 source memref is addressed with a single index.
// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
// CHECK-SAME: %[[SRC:.*]]: memref<?xi8>)
func.func @vector_load_i8_from_rank_1_memref(%src : memref<?xi8>) -> vector<[16]x[16]xi8> {
  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
  // CHECK: arm_sme.tile_load %[[SRC]][%[[ZERO]]] : memref<?xi8>, vector<[16]x[16]xi8>
  %zero = arith.constant 0 : index
  %loaded = vector.load %src[%zero] : memref<?xi8>, vector<[16]x[16]xi8>
  return %loaded : vector<[16]x[16]xi8>
}
730
731// -----
732
// Load an i16 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_i16(
func.func @vector_load_i16(%mem : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xi16>, vector<[8]x[8]xi16>
  return %res : vector<[8]x[8]xi16>
}
740
741// -----
742
// Load an i32 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_i32(
func.func @vector_load_i32(%mem : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xi32>, vector<[4]x[4]xi32>
  return %res : vector<[4]x[4]xi32>
}
750
751// -----
752
// Load an i64 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_i64(
func.func @vector_load_i64(%mem : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xi64>, vector<[2]x[2]xi64>
  return %res : vector<[2]x[2]xi64>
}
760
761// -----
762
// Load an f16 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_f16(
func.func @vector_load_f16(%mem : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xf16>, vector<[8]x[8]xf16>
  return %res : vector<[8]x[8]xf16>
}
770
771// -----
772
// Load a bf16 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_bf16(
func.func @vector_load_bf16(%mem : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
  return %res : vector<[8]x[8]xbf16>
}
780
781// -----
782
// Load an f32 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_f32(
func.func @vector_load_f32(%mem : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xf32>, vector<[4]x[4]xf32>
  return %res : vector<[4]x[4]xf32>
}
790
791// -----
792
// Load an f64 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_f64(
func.func @vector_load_f64(%mem : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xf64>, vector<[2]x[2]xf64>
  return %res : vector<[2]x[2]xf64>
}
800
801// -----
802
// Load an i128 tile from offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_load_i128(
func.func @vector_load_i128(%mem : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
  %zero = arith.constant 0 : index
  %res = vector.load %mem[%zero, %zero] : memref<?x?xi128>, vector<[1]x[1]xi128>
  return %res : vector<[1]x[1]xi128>
}
810
811
812//===----------------------------------------------------------------------===//
813// vector.store
814//===----------------------------------------------------------------------===//
815
816// -----
817
// Store a tile obtained from arm_sme.get_tile to offset [0, 0] of an i8 memref.
// CHECK-LABEL: @vector_store_i8(
// CHECK-SAME: %[[DST:.*]]: memref<?x?xi8>) {
func.func @vector_store_i8(%dst : memref<?x?xi8>) {
  // CHECK: %[[ORIGIN:.*]] = arith.constant 0 : index
  // CHECK: %[[SRC:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
  // CHECK: arm_sme.tile_store %[[SRC]], %[[DST]][%[[ORIGIN]], %[[ORIGIN]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[16]x[16]xi8>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xi8>, vector<[16]x[16]xi8>
  return
}
829
830// -----
831
// Store an i16 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_i16
func.func @vector_store_i16(%dst : memref<?x?xi16>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[8]x[8]xi16>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xi16>, vector<[8]x[8]xi16>
  return
}
840
841// -----
842
// Store an i32 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_i32
func.func @vector_store_i32(%dst : memref<?x?xi32>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[4]x[4]xi32>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xi32>, vector<[4]x[4]xi32>
  return
}
851
852// -----
853
// Store an i64 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_i64
func.func @vector_store_i64(%dst : memref<?x?xi64>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[2]x[2]xi64>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xi64>, vector<[2]x[2]xi64>
  return
}
862
863// -----
864
// Store an f16 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_f16
func.func @vector_store_f16(%dst : memref<?x?xf16>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[8]x[8]xf16>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xf16>, vector<[8]x[8]xf16>
  return
}
873
874// -----
875
// Store a bf16 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_bf16
func.func @vector_store_bf16(%dst : memref<?x?xbf16>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[8]x[8]xbf16>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
  return
}
884// -----
885
// Store an f32 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_f32
func.func @vector_store_f32(%dst : memref<?x?xf32>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[4]x[4]xf32>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xf32>, vector<[4]x[4]xf32>
  return
}
894
895// -----
896
// Store an f64 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_f64
func.func @vector_store_f64(%dst : memref<?x?xf64>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[2]x[2]xf64>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xf64>, vector<[2]x[2]xf64>
  return
}
905
906// -----
907
// Store an i128 tile to offset [0, 0] of a 2-D memref.
// CHECK-LABEL: @vector_store_i128
func.func @vector_store_i128(%dst : memref<?x?xi128>) {
  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
  %origin = arith.constant 0 : index
  %src = arm_sme.get_tile : vector<[1]x[1]xi128>
  vector.store %src, %dst[%origin, %origin] : memref<?x?xi128>, vector<[1]x[1]xi128>
  return
}
916
917//===----------------------------------------------------------------------===//
918// vector.insert
919//===----------------------------------------------------------------------===//
920
921// -----
922
// Inserting a 1-D slice at a dynamic position maps directly onto
// arm_sme.insert_tile_slice.
// CHECK-LABEL: @vector_insert_slice_i32(
// CHECK-SAME: %[[NEW_SLICE:.*]]: vector<[4]xi32>,
// CHECK-SAME: %[[SLICE_IDX:.*]]: index)
// CHECK-NEXT: %[[BASE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-NEXT: arm_sme.insert_tile_slice %[[NEW_SLICE]], %[[BASE]][%[[SLICE_IDX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
func.func @vector_insert_slice_i32(%vec: vector<[4]xi32>, %idx: index) -> vector<[4]x[4]xi32> {
  %base = arm_sme.get_tile : vector<[4]x[4]xi32>
  %updated = vector.insert %vec, %base[%idx] : vector<[4]xi32> into vector<[4]x[4]xi32>
  return %updated : vector<[4]x[4]xi32>
}
933
934// -----
935
// Insert a 1-D i8 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_i8
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8>
func.func @vector_insert_slice_i8(%vec: vector<[16]xi8>, %idx: index) -> vector<[16]x[16]xi8> {
  %base = arm_sme.get_tile : vector<[16]x[16]xi8>
  %updated = vector.insert %vec, %base[%idx] : vector<[16]xi8> into vector<[16]x[16]xi8>
  return %updated : vector<[16]x[16]xi8>
}
943
944// -----
945
// Insert a 1-D i16 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_i16
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xi16> into vector<[8]x[8]xi16>
func.func @vector_insert_slice_i16(%vec: vector<[8]xi16>, %idx: index) -> vector<[8]x[8]xi16> {
  %base = arm_sme.get_tile : vector<[8]x[8]xi16>
  %updated = vector.insert %vec, %base[%idx] : vector<[8]xi16> into vector<[8]x[8]xi16>
  return %updated : vector<[8]x[8]xi16>
}
953
954// -----
955
// Insert a 1-D i64 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_i64
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xi64> into vector<[2]x[2]xi64>
func.func @vector_insert_slice_i64(%vec: vector<[2]xi64>, %idx: index) -> vector<[2]x[2]xi64> {
  %base = arm_sme.get_tile : vector<[2]x[2]xi64>
  %updated = vector.insert %vec, %base[%idx] : vector<[2]xi64> into vector<[2]x[2]xi64>
  return %updated : vector<[2]x[2]xi64>
}
963
964// -----
965
// Insert a 1-D i128 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_i128
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[1]xi128> into vector<[1]x[1]xi128>
func.func @vector_insert_slice_i128(%vec: vector<[1]xi128>, %idx: index) -> vector<[1]x[1]xi128> {
  %base = arm_sme.get_tile : vector<[1]x[1]xi128>
  %updated = vector.insert %vec, %base[%idx] : vector<[1]xi128> into vector<[1]x[1]xi128>
  return %updated : vector<[1]x[1]xi128>
}
973
974// -----
975
// Insert a 1-D f16 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_f16
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
func.func @vector_insert_slice_f16(%vec: vector<[8]xf16>, %idx: index) -> vector<[8]x[8]xf16> {
  %base = arm_sme.get_tile : vector<[8]x[8]xf16>
  %updated = vector.insert %vec, %base[%idx] : vector<[8]xf16> into vector<[8]x[8]xf16>
  return %updated : vector<[8]x[8]xf16>
}
983
984// -----
985
// Insert a 1-D bf16 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_bf16
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xbf16> into vector<[8]x[8]xbf16>
func.func @vector_insert_slice_bf16(%vec: vector<[8]xbf16>, %idx: index) -> vector<[8]x[8]xbf16> {
  %base = arm_sme.get_tile : vector<[8]x[8]xbf16>
  %updated = vector.insert %vec, %base[%idx] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
  return %updated : vector<[8]x[8]xbf16>
}
993
994// -----
995
// Insert a 1-D f32 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_f32
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[4]xf32> into vector<[4]x[4]xf32>
func.func @vector_insert_slice_f32(%vec: vector<[4]xf32>, %idx: index) -> vector<[4]x[4]xf32> {
  %base = arm_sme.get_tile : vector<[4]x[4]xf32>
  %updated = vector.insert %vec, %base[%idx] : vector<[4]xf32> into vector<[4]x[4]xf32>
  return %updated : vector<[4]x[4]xf32>
}
1003
1004// -----
1005
// Insert a 1-D f64 slice into a tile at a dynamic position.
// CHECK-LABEL: @vector_insert_slice_f64
// CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xf64> into vector<[2]x[2]xf64>
func.func @vector_insert_slice_f64(%vec: vector<[2]xf64>, %idx: index) -> vector<[2]x[2]xf64> {
  %base = arm_sme.get_tile : vector<[2]x[2]xf64>
  %updated = vector.insert %vec, %base[%idx] : vector<[2]xf64> into vector<[2]x[2]xf64>
  return %updated : vector<[2]x[2]xf64>
}
1013
1014// -----
1015
// Inserting one i32 element goes through its enclosing slice: the slice is
// extracted, updated with a 1-D vector.insert, and written back to the tile.
// CHECK-LABEL: @vector_insert_element_i32(
// CHECK-SAME: %[[VAL:.*]]: i32,
// CHECK-SAME: %[[R:.*]]: index,
// CHECK-SAME: %[[C:.*]]: index)
// CHECK-NEXT: %[[BASE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-NEXT: %[[OLD_SLICE:.*]] = arm_sme.extract_tile_slice %[[BASE]][%[[R]]] : vector<[4]xi32> from vector<[4]x[4]xi32>
// CHECK-NEXT: %[[UPDATED_SLICE:.*]] = vector.insert %[[VAL]], %[[OLD_SLICE]] [%[[C]]] : i32 into vector<[4]xi32>
// CHECK-NEXT: arm_sme.insert_tile_slice %[[UPDATED_SLICE]], %[[BASE]][%[[R]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
func.func @vector_insert_element_i32(%val: i32, %r: index, %c: index) -> vector<[4]x[4]xi32> {
  %base = arm_sme.get_tile : vector<[4]x[4]xi32>
  %updated = vector.insert %val, %base[%r, %c] : i32 into vector<[4]x[4]xi32>
  return %updated : vector<[4]x[4]xi32>
}
1029
1030// -----
1031
// Insert a single i8 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_i8
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[16]xi8> into vector<[16]x[16]xi8>
func.func @vector_insert_element_i8(%val: i8, %r: index, %c: index) -> vector<[16]x[16]xi8> {
  %base = arm_sme.get_tile : vector<[16]x[16]xi8>
  %updated = vector.insert %val, %base[%r, %c] : i8 into vector<[16]x[16]xi8>
  return %updated : vector<[16]x[16]xi8>
}
1041
1042// -----
1043
// Insert a single i16 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_i16
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[8]x[8]xi16>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[8]xi16> into vector<[8]x[8]xi16>
func.func @vector_insert_element_i16(%val: i16, %r: index, %c: index) -> vector<[8]x[8]xi16> {
  %base = arm_sme.get_tile : vector<[8]x[8]xi16>
  %updated = vector.insert %val, %base[%r, %c] : i16 into vector<[8]x[8]xi16>
  return %updated : vector<[8]x[8]xi16>
}
1053
1054// -----
1055
// Insert a single i64 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_i64
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[2]x[2]xi64>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[2]xi64> into vector<[2]x[2]xi64>
func.func @vector_insert_element_i64(%val: i64, %r: index, %c: index) -> vector<[2]x[2]xi64> {
  %base = arm_sme.get_tile : vector<[2]x[2]xi64>
  %updated = vector.insert %val, %base[%r, %c] : i64 into vector<[2]x[2]xi64>
  return %updated : vector<[2]x[2]xi64>
}
1065
1066// -----
1067
// Insert a single i128 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_i128
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[1]x[1]xi128>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[1]xi128> into vector<[1]x[1]xi128>
func.func @vector_insert_element_i128(%val: i128, %r: index, %c: index) -> vector<[1]x[1]xi128> {
  %base = arm_sme.get_tile : vector<[1]x[1]xi128>
  %updated = vector.insert %val, %base[%r, %c] : i128 into vector<[1]x[1]xi128>
  return %updated : vector<[1]x[1]xi128>
}
1077
1078// -----
1079
// Insert a single f16 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_f16
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[8]x[8]xf16>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[8]xf16> into vector<[8]x[8]xf16>
func.func @vector_insert_element_f16(%val: f16, %r: index, %c: index) -> vector<[8]x[8]xf16> {
  %base = arm_sme.get_tile : vector<[8]x[8]xf16>
  %updated = vector.insert %val, %base[%r, %c] : f16 into vector<[8]x[8]xf16>
  return %updated : vector<[8]x[8]xf16>
}
1089
1090// -----
1091
// Insert a single bf16 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_bf16
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[8]x[8]xbf16>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
func.func @vector_insert_element_bf16(%val: bf16, %r: index, %c: index) -> vector<[8]x[8]xbf16> {
  %base = arm_sme.get_tile : vector<[8]x[8]xbf16>
  %updated = vector.insert %val, %base[%r, %c] : bf16 into vector<[8]x[8]xbf16>
  return %updated : vector<[8]x[8]xbf16>
}
1101
1102// -----
1103
// Insert a single f32 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_f32
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[4]xf32> into vector<[4]x[4]xf32>
func.func @vector_insert_element_f32(%val: f32, %r: index, %c: index) -> vector<[4]x[4]xf32> {
  %base = arm_sme.get_tile : vector<[4]x[4]xf32>
  %updated = vector.insert %val, %base[%r, %c] : f32 into vector<[4]x[4]xf32>
  return %updated : vector<[4]x[4]xf32>
}
1113
1114// -----
1115
// Insert a single f64 element: the enclosing slice is extracted from and then
// re-inserted into the same tile.
// CHECK-LABEL: @vector_insert_element_f64
// CHECK: %[[BASE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: arm_sme.extract_tile_slice %[[BASE]]{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
// CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[BASE]][%{{.*}}] : vector<[2]xf64> into vector<[2]x[2]xf64>
func.func @vector_insert_element_f64(%val: f64, %r: index, %c: index) -> vector<[2]x[2]xf64> {
  %base = arm_sme.get_tile : vector<[2]x[2]xf64>
  %updated = vector.insert %val, %base[%r, %c] : f64 into vector<[2]x[2]xf64>
  return %updated : vector<[2]x[2]xf64>
}
1125
1126//===----------------------------------------------------------------------===//
1127// vector.extract --> arm_sme.extract_tile_slice
1128//===----------------------------------------------------------------------===//
1129
1130// -----
1131
// Extracting a 1-D slice at a dynamic position maps onto
// arm_sme.extract_tile_slice using the same index.
// CHECK-LABEL: @vector_extract_slice_i32(
// CHECK-SAME: %[[SLICE_IDX:.*]]: index)
// CHECK: %[[SRC:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: arm_sme.extract_tile_slice %[[SRC]][%[[SLICE_IDX]]] : vector<[4]xi32> from vector<[4]x[4]xi32>
func.func @vector_extract_slice_i32(%idx: index) -> vector<[4]xi32> {
  %t = arm_sme.get_tile : vector<[4]x[4]xi32>
  %s = vector.extract %t[%idx] : vector<[4]xi32> from vector<[4]x[4]xi32>
  return %s : vector<[4]xi32>
}
1141
1142// -----
1143
// Read one 1-D i8 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_i8
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
func.func @vector_extract_slice_i8(%idx: index) -> vector<[16]xi8> {
  %t = arm_sme.get_tile : vector<[16]x[16]xi8>
  %s = vector.extract %t[%idx] : vector<[16]xi8> from vector<[16]x[16]xi8>
  return %s : vector<[16]xi8>
}
1151
1152// -----
1153
// Read one 1-D i16 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_i16
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
func.func @vector_extract_slice_i16(%idx: index) -> vector<[8]xi16> {
  %t = arm_sme.get_tile : vector<[8]x[8]xi16>
  %s = vector.extract %t[%idx] : vector<[8]xi16> from vector<[8]x[8]xi16>
  return %s : vector<[8]xi16>
}
1161
1162// -----
1163
// Read one 1-D i64 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_i64
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
func.func @vector_extract_slice_i64(%idx: index) -> vector<[2]xi64> {
  %t = arm_sme.get_tile : vector<[2]x[2]xi64>
  %s = vector.extract %t[%idx] : vector<[2]xi64> from vector<[2]x[2]xi64>
  return %s : vector<[2]xi64>
}
1171
1172// -----
1173
// Read one 1-D i128 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_i128
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
func.func @vector_extract_slice_i128(%idx: index) -> vector<[1]xi128> {
  %t = arm_sme.get_tile : vector<[1]x[1]xi128>
  %s = vector.extract %t[%idx] : vector<[1]xi128> from vector<[1]x[1]xi128>
  return %s : vector<[1]xi128>
}
1181
1182// -----
1183
// Read one 1-D f16 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_f16
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
func.func @vector_extract_slice_f16(%idx: index) -> vector<[8]xf16> {
  %t = arm_sme.get_tile : vector<[8]x[8]xf16>
  %s = vector.extract %t[%idx] : vector<[8]xf16> from vector<[8]x[8]xf16>
  return %s : vector<[8]xf16>
}
1191
1192// -----
1193
// Read one 1-D bf16 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_bf16
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
func.func @vector_extract_slice_bf16(%idx: index) -> vector<[8]xbf16> {
  %t = arm_sme.get_tile : vector<[8]x[8]xbf16>
  %s = vector.extract %t[%idx] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
  return %s : vector<[8]xbf16>
}
1201
1202// -----
1203
// Read one 1-D f32 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_f32
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
func.func @vector_extract_slice_f32(%idx: index) -> vector<[4]xf32> {
  %t = arm_sme.get_tile : vector<[4]x[4]xf32>
  %s = vector.extract %t[%idx] : vector<[4]xf32> from vector<[4]x[4]xf32>
  return %s : vector<[4]xf32>
}
1211
1212// -----
1213
// Read one 1-D f64 slice out of a tile at a dynamic position.
// CHECK-LABEL: @vector_extract_slice_f64
// CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
func.func @vector_extract_slice_f64(%idx: index) -> vector<[2]xf64> {
  %t = arm_sme.get_tile : vector<[2]x[2]xf64>
  %s = vector.extract %t[%idx] : vector<[2]xf64> from vector<[2]x[2]xf64>
  return %s : vector<[2]xf64>
}
1221
1222// -----
1223
// Reading one i32 element is lowered to a slice extraction followed by a 1-D
// vector.extract on that slice.
// CHECK-LABEL: @vector_extract_element(
// CHECK-SAME: %[[R:.*]]: index,
// CHECK-SAME: %[[C:.*]]: index)
// CHECK-NEXT: %[[SRC:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-NEXT: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %[[SRC]][%[[R]]] : vector<[4]xi32> from vector<[4]x[4]xi32>
// CHECK-NEXT: %[[ELEM:.*]] = vector.extract %[[SLICE_VEC]]{{\[}}%[[C]]] : i32 from vector<[4]xi32>
func.func @vector_extract_element(%r: index, %c: index) -> i32 {
  %t = arm_sme.get_tile : vector<[4]x[4]xi32>
  %elem = vector.extract %t[%r, %c] : i32 from vector<[4]x[4]xi32>
  return %elem : i32
}
1235
1236// -----
1237
// Read a single i8 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_i8
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
func.func @vector_extract_element_i8(%r: index, %c: index) -> i8 {
  %t = arm_sme.get_tile : vector<[16]x[16]xi8>
  %elem = vector.extract %t[%r, %c] : i8 from vector<[16]x[16]xi8>
  return %elem : i8
}
1246
1247// -----
1248
// Read a single i16 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_i16
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
func.func @vector_extract_element_i16(%r: index, %c: index) -> i16 {
  %t = arm_sme.get_tile : vector<[8]x[8]xi16>
  %elem = vector.extract %t[%r, %c] : i16 from vector<[8]x[8]xi16>
  return %elem : i16
}
1257
1258// -----
1259
// Read a single i64 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_i64
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
func.func @vector_extract_element_i64(%r: index, %c: index) -> i64 {
  %t = arm_sme.get_tile : vector<[2]x[2]xi64>
  %elem = vector.extract %t[%r, %c] : i64 from vector<[2]x[2]xi64>
  return %elem : i64
}
1268
1269// -----
1270
// Read a single i128 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_i128
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
func.func @vector_extract_element_i128(%r: index, %c: index) -> i128 {
  %t = arm_sme.get_tile : vector<[1]x[1]xi128>
  %elem = vector.extract %t[%r, %c] : i128 from vector<[1]x[1]xi128>
  return %elem : i128
}
1279
1280// -----
1281
// Read a single f16 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_f16
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
func.func @vector_extract_element_f16(%r: index, %c: index) -> f16 {
  %t = arm_sme.get_tile : vector<[8]x[8]xf16>
  %elem = vector.extract %t[%r, %c] : f16 from vector<[8]x[8]xf16>
  return %elem : f16
}
1290
1291// -----
1292
// Read a single bf16 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_bf16
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
func.func @vector_extract_element_bf16(%r: index, %c: index) -> bf16 {
  %t = arm_sme.get_tile : vector<[8]x[8]xbf16>
  %elem = vector.extract %t[%r, %c] : bf16 from vector<[8]x[8]xbf16>
  return %elem : bf16
}
1301
1302// -----
1303
// Read a single f32 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_f32
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
func.func @vector_extract_element_f32(%r: index, %c: index) -> f32 {
  %t = arm_sme.get_tile : vector<[4]x[4]xf32>
  %elem = vector.extract %t[%r, %c] : f32 from vector<[4]x[4]xf32>
  return %elem : f32
}
1312
1313// -----
1314
// Read a single f64 element: first the enclosing slice, then the element.
// CHECK-LABEL: @vector_extract_element_f64
// CHECK: %[[SLICE_VEC:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE_VEC]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
func.func @vector_extract_element_f64(%r: index, %c: index) -> f64 {
  %t = arm_sme.get_tile : vector<[2]x[2]xf64>
  %elem = vector.extract %t[%r, %c] : f64 from vector<[2]x[2]xf64>
  return %elem : f64
}
1323
1324//===----------------------------------------------------------------------===//
1325// vector.extract --> arm_sve.psel
1326//===----------------------------------------------------------------------===//
1327
1328// -----
1329
// Extracting a dynamic slice of a 2-D create_mask is lowered to two 1-D
// create_masks combined by arm_sve.psel at the requested index.
// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
// CHECK-SAME: %[[DIM_A:.*]]: index, %[[DIM_B:.*]]: index, %[[POS:.*]]: index)
// CHECK: %[[ROWS_MASK:.*]] = vector.create_mask %[[DIM_A]] : vector<[4]xi1>
// CHECK: %[[COLS_MASK:.*]] = vector.create_mask %[[DIM_B]] : vector<[8]xi1>
// CHECK: arm_sve.psel %[[COLS_MASK]], %[[ROWS_MASK]][%[[POS]]] : vector<[8]xi1>, vector<[4]xi1>
func.func @dynamic_vector_extract_mask_to_psel(%num_rows: index, %num_cols: index, %pos: index) -> vector<[8]xi1> {
  %m = vector.create_mask %num_rows, %num_cols : vector<[4]x[8]xi1>
  %picked = vector.extract %m[%pos] : vector<[8]xi1> from vector<[4]x[8]xi1>
  return %picked : vector<[8]xi1>
}
1341
1342// -----
1343
// Same lowering as the dynamic case, but with a constant position that is
// materialized as an index constant for arm_sve.psel.
// CHECK-LABEL: @vector_extract_mask_to_psel(
// CHECK-SAME: %[[DIM_A:.*]]: index,
// CHECK-SAME: %[[DIM_B:.*]]: index)
// CHECK: %[[ONE:.*]] = arith.constant 1 : index
// CHECK: %[[ROWS_MASK:.*]] = vector.create_mask %[[DIM_A]] : vector<[16]xi1>
// CHECK: %[[COLS_MASK:.*]] = vector.create_mask %[[DIM_B]] : vector<[2]xi1>
// CHECK: arm_sve.psel %[[COLS_MASK]], %[[ROWS_MASK]][%[[ONE]]] : vector<[2]xi1>, vector<[16]xi1>
func.func @vector_extract_mask_to_psel(%num_rows: index, %num_cols: index) -> vector<[2]xi1> {
  %m = vector.create_mask %num_rows, %num_cols : vector<[16]x[2]xi1>
  %picked = vector.extract %m[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
  return %picked : vector<[2]xi1>
}
1357