// xref: /llvm-project/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (revision fe55c34d19628304e0ca6a0e14a0b786b93d0e02)
// DEFINE: %{entry_point} = za0_d_f64
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
// DEFINE:  -e %{entry_point} -entry-point-result=i32 \
// DEFINE:  -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D

// REDEFINE: %{entry_point} = load_store_two_za_s_tiles
// RUN: %{compile} | %{run} | FileCheck %s

// Integration tests demonstrating load/store to/from SME ZA tile.

// This test verifies a 64-bit element ZA tile with FP64 data is correctly
// loaded/stored to/from memory.
//
// The function takes no arguments. It prints intermediate state with
// `vector.print` (the printed output is what FileCheck verifies) and always
// returns 0.
func.func @za0_d_f64() -> i32 {
  %c0 = arith.constant 0 : index
  %c0_f64 = arith.constant 0.0 : f64
  %c1_f64 = arith.constant 1.0 : f64
  %c1_index = arith.constant 1 : index

  // "svl" refers to the Streaming Vector Length and "svl_d" the number of
  // 64-bit elements in a vector of SVL bits.
  %svl_d = arm_sme.streaming_vl <double>

  // Allocate "mem1" and fill each "row" with its row number plus 0.1.
  //
  // For example, assuming an SVL of 256-bits:
  //
  //   0.1, 0.1, 0.1, 0.1
  //   1.1, 1.1, 1.1, 1.1
  //   2.1, 2.1, 2.1, 2.1
  //   3.1, 3.1, 3.1, 3.1
  //
  %tilesize = arith.muli %svl_d, %svl_d : index
  %mem1 = memref.alloca(%tilesize) : memref<?xf64>
  // Value stored in the first row; each following row adds 1.0.
  %first_row_val = arith.constant 0.1 : f64
  scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %first_row_val) -> (f64) {
    %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    %val_next = arith.addf %val, %c1_f64 : f64
    scf.yield %val_next : f64
  }

  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
  // 2x2xf64.
  //
  // CHECK-ZA0_D:      ( 0.1, 0.1
  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  // Load ZA0.D from "mem1"
  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Allocate "mem2" to store ZA0.D to
  %mem2 = memref.alloca(%tilesize) : memref<?xf64>

  // Zero "mem2"
  scf.for %i = %c0 to %tilesize step %c1_index {
    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
  }

  // Verify "mem2" is zeroed by doing an add reduction with initial value of
  // zero. The outer loop walks the rows, the inner loop sums the elements of
  // one row.
  %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %c0_f64) -> (f64) {
    %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>

    %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %c0_f64) -> (f64) {
      %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
      %inner_add_reduce_next = arith.addf %inner_iter, %t : f64
      scf.yield %inner_add_reduce_next : f64
    }

    %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64
    scf.yield %add_reduce_next : f64
  }

  // CHECK-ZA0_D: 0
  vector.print %add_reduce : f64

  // Dump zeroed "mem2". The smallest SVL is 128-bits so the tile will be at
  // least 2x2xf64.
  //
  // CHECK-ZA0_D-NEXT: ( 0, 0
  // CHECK-ZA0_D-NEXT: ( 0, 0
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  // Verify "mem1" != "mem2". Each element comparison is widened to i64 and
  // multiplied into the running product, so the result is 1 only if every
  // element differs.
  %init_1 = arith.constant 1 : i64
  %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
    %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>

    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
      %t_i64 = arith.extui %t : i1 to i64
      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
      scf.yield %inner_mul_reduce_next : i64
    }

    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
    scf.yield %mul_reduce_next : i64
  }

  // CHECK-ZA0_D: 1
  vector.print %mul_reduce_0 : i64

  // Store ZA0.D to "mem2"
  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Verify "mem1" == "mem2". Same product reduction as above, but with an
  // ordered-equal comparison, so the result is 1 only if every element matches.
  %mul_reduce_1 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
    %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>

    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
      %t_i64 = arith.extui %t : i1 to i64
      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
      scf.yield %inner_mul_reduce_next : i64
    }

    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
    scf.yield %mul_reduce_next : i64
  }

  // CHECK-ZA0_D-NEXT: 1
  vector.print %mul_reduce_1 : i64

  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
  // 2x2xf64.
  //
  // CHECK-ZA0_D-NEXT: ( 0.1, 0.1
  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  %c0_i32 = arith.constant 0 : i32
  return %c0_i32 : i32
}

// This test loads two 32-bit element ZA tiles from memory and stores them back
// to memory in reverse order. This verifies the memref indices for the vector
// load and store are correctly preserved since the second tile is offset from
// the first tile.
//
// The function takes no arguments. It prints the memory contents with
// `vector.print` (the printed output is what FileCheck verifies) and always
// returns 0.
func.func @load_store_two_za_s_tiles() -> i32 {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c1_index = arith.constant 1 : index
  %c2_index = arith.constant 2 : index

  // "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
  // * the number of 32-bit elements in a vector of SVL bits.
  // * the number of tile slices (1d vectors) in a 32-bit element tile.
  %svl_s = arm_sme.streaming_vl <word>

  // Allocate memory for two 32-bit element tiles.
  %size_of_tile = arith.muli %svl_s, %svl_s : index
  %size_of_two_tiles = arith.muli %size_of_tile, %c2_index : index
  %mem1 = memref.alloca(%size_of_two_tiles) : memref<?xi32>

  // Fill the memory tile 1 will be loaded from with '1', and the memory tile 2
  // will be loaded from with '2'.
  //
  // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles:
  //
  // tile 1
  //
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //
  // tile 2
  //
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    // Offsets below the size of one tile belong to tile 1.
    %isFirstTile = arith.cmpi ult, %i, %size_of_tile : index
    %val = scf.if %isFirstTile -> i32 {
      scf.yield %c1_i32 : i32
    } else {
      scf.yield %c2_i32 : i32
    }
    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
  }

  // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least
  // 4x4xi32.
  //
  // CHECK:      ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK:      ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
    vector.print %tileslice : vector<[4]xi32>
  }

  // Load tile 1 from memory
  %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Load tile 2 from memory
  %za1_s = vector.load %mem1[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>

  // Allocate new memory to store tiles to
  %mem2 = memref.alloca(%size_of_two_tiles) : memref<?xi32>

  // Zero new memory
  scf.for %i = %c0 to %size_of_two_tiles step %c1_index {
    memref.store %c0_i32, %mem2[%i] : memref<?xi32>
  }

  // Stores tiles back to (new) memory in reverse order

  // Store tile 2 to memory
  vector.store %za1_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Store tile 1 to memory
  vector.store %za0_s, %mem2[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>

  // Dump "mem2" and check the tiles were stored in reverse order. The smallest
  // SVL is 128-bits so the tiles will be at least 4x4xi32.
  //
  // CHECK:      TILE BEGIN
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK:      TILE END
  // CHECK-NEXT: TILE BEGIN
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK:      TILE END

  // Offset of the last tile slice of the first tile. The value does not depend
  // on the loop induction variable, so it is computed once before the loop.
  %tileSizeMinusStep = arith.subi %size_of_tile, %svl_s : index

  vector.print str "TILE BEGIN\n"
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
    vector.print %tileslice : vector<[4]xi32>

    // After printing the last slice of the first tile, close that tile and
    // open the next one.
    %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index
    scf.if %isNextTile {
      vector.print str "TILE END\n"
      vector.print str "TILE BEGIN\n"
    }
  }
  vector.print str "TILE END\n"

  return %c0_i32 : i32
}