// DEFINE: %{entry_point} = za0_d_f64
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=i32 \
// DEFINE: -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D

// REDEFINE: %{entry_point} = load_store_two_za_s_tiles
// RUN: %{compile} | %{run} | FileCheck %s

// Integration tests demonstrating load/store to/from SME ZA tile.

// This test verifies a 64-bit element ZA with FP64 data is correctly
// loaded/stored to/from memory.
func.func @za0_d_f64() -> i32 {
  %c0 = arith.constant 0 : index
  %c0_f64 = arith.constant 0.0 : f64
  %c1_f64 = arith.constant 1.0 : f64
  %c1_index = arith.constant 1 : index

  // "svl" refers to the Streaming Vector Length and "svl_d" the number of
  // 64-bit elements in a vector of SVL bits.
  %svl_d = arm_sme.streaming_vl <double>

  // Allocate "mem1" and fill each "row" with its row number plus 0.1.
  //
  // For example, assuming an SVL of 256-bits:
  //
  //   0.1, 0.1, 0.1, 0.1
  //   1.1, 1.1, 1.1, 1.1
  //   2.1, 2.1, 2.1, 2.1
  //   3.1, 3.1, 3.1, 3.1
  //
  %tilesize = arith.muli %svl_d, %svl_d : index
  %mem1 = memref.alloca(%tilesize) : memref<?xf64>
  %init_0 = arith.constant 0.1 : f64
  scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) {
    %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    %val_next = arith.addf %val, %c1_f64 : f64
    scf.yield %val_next : f64
  }

  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
  // 2x2xf64.
  //
  // CHECK-ZA0_D:      ( 0.1, 0.1
  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  // Load ZA0.D from "mem1"
  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Allocate "mem2" to store ZA0.D to
  %mem2 = memref.alloca(%tilesize) : memref<?xf64>

  // Zero "mem2"
  scf.for %i = %c0 to %tilesize step %c1_index {
    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
  }

  // Verify "mem2" is zeroed by doing an add reduction (over every element,
  // one row at a time) with an initial value of zero.
  %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %c0_f64) -> (f64) {
    %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>

    %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %c0_f64) -> (f64) {
      %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
      %inner_add_reduce_next = arith.addf %inner_iter, %t : f64
      scf.yield %inner_add_reduce_next : f64
    }

    %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64
    scf.yield %add_reduce_next : f64
  }

  // The match is anchored to a whole line: for large SVLs (e.g. 1024 bits)
  // the "mem1" dump above contains rows such as "( 10.1, ..." in which an
  // unanchored "0" pattern would match, breaking the CHECK-NEXTs below.
  // CHECK-ZA0_D: {{^}}0{{$}}
  vector.print %add_reduce : f64

  // Dump zeroed "mem2". The smallest SVL is 128-bits so the tile will be at
  // least 2x2xf64.
  //
  // CHECK-ZA0_D-NEXT: ( 0, 0
  // CHECK-ZA0_D-NEXT: ( 0, 0
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  // Verify "mem1" != "mem2": the product of all per-element "not equal"
  // flags (as i64) is 1 only if every element differs.
  %init_1 = arith.constant 1 : i64
  %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
    %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>

    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
      %t_i64 = arith.extui %t : i1 to i64
      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
      scf.yield %inner_mul_reduce_next : i64
    }

    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
    scf.yield %mul_reduce_next : i64
  }

  // CHECK-ZA0_D: {{^}}1{{$}}
  vector.print %mul_reduce_0 : i64

  // Store ZA0.D to "mem2"
  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Verify "mem1" == "mem2": the product of all per-element "equal" flags
  // (as i64) is 1 only if every element matches.
  %mul_reduce_1 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
    %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>

    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
      %t_i64 = arith.extui %t : i1 to i64
      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
      scf.yield %inner_mul_reduce_next : i64
    }

    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
    scf.yield %mul_reduce_next : i64
  }

  // CHECK-ZA0_D-NEXT: 1
  vector.print %mul_reduce_1 : i64

  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
  // 2x2xf64.
  //
  // CHECK-ZA0_D-NEXT: ( 0.1, 0.1
  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  %c0_i32 = arith.constant 0 : i32
  return %c0_i32 : i32
}

// This test loads two 32-bit element ZA tiles from memory and stores them back
// to memory in reverse order. This verifies the memref indices for the vector
// load and store are correctly preserved since the second tile is offset from
// the first tile.
func.func @load_store_two_za_s_tiles() -> i32 {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c1_index = arith.constant 1 : index
  %c2_index = arith.constant 2 : index

  // "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
  // * the number of 32-bit elements in a vector of SVL bits.
  // * the number of tile slices (1d vectors) in a 32-bit element tile.
  %svl_s = arm_sme.streaming_vl <word>

  // Allocate memory for two 32-bit element tiles.
  %size_of_tile = arith.muli %svl_s, %svl_s : index
  %size_of_two_tiles = arith.muli %size_of_tile, %c2_index : index
  %mem1 = memref.alloca(%size_of_two_tiles) : memref<?xi32>

  // Fill memory that tile 1 will be loaded from with '1' and '2' for tile 2.
  //
  // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles:
  //
  // tile 1
  //
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //
  // tile 2
  //
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %isFirstTile = arith.cmpi ult, %i, %size_of_tile : index
    %val = scf.if %isFirstTile -> i32 {
      scf.yield %c1_i32 : i32
    } else {
      scf.yield %c2_i32 : i32
    }
    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
  }

  // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least
  // 4x4xi32.
  //
  // CHECK:      ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK:      ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
    vector.print %tileslice : vector<[4]xi32>
  }

  // Load tile 1 from memory
  %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Load tile 2 from memory
  %za1_s = vector.load %mem1[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>

  // Allocate new memory to store tiles to
  %mem2 = memref.alloca(%size_of_two_tiles) : memref<?xi32>

  // Zero new memory
  scf.for %i = %c0 to %size_of_two_tiles step %c1_index {
    memref.store %c0_i32, %mem2[%i] : memref<?xi32>
  }

  // Store tiles back to (new) memory in reverse order

  // Store tile 2 to memory
  vector.store %za1_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Store tile 1 to memory
  vector.store %za0_s, %mem2[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>

  // Dump "mem2" and check the tiles were stored in reverse order. The smallest
  // SVL is 128-bits so the tiles will be at least 4x4xi32.
  //
  // CHECK:      TILE BEGIN
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK:      TILE END
  // CHECK-NEXT: TILE BEGIN
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK:      TILE END

  // Offset of the last slice of the first tile; after printing it, the tile
  // boundary markers are emitted. Loop-invariant, so computed once here.
  %tileSizeMinusStep = arith.subi %size_of_tile, %svl_s : index
  vector.print str "TILE BEGIN\n"
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
    vector.print %av : vector<[4]xi32>

    %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index
    scf.if %isNextTile {
      vector.print str "TILE END\n"
      vector.print str "TILE BEGIN\n"
    }
  }
  vector.print str "TILE END\n"

  return %c0_i32 : i32
}