// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm="enable-x86vector" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// This test shows how to implement a sparse vector-vector dot product with
// AVX512. It uses vp2intersect, mask.compress and vector.contract to compute
// the dot product of two sparse HW vectors of 8 float64 elements each (a
// "segment"). Each sparse vector is represented by an index memref (A or C)
// and by a data memref (B or D), containing M or N elements.
//
// There are four different implementations:
// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
// * `memref_dot_optimized`: An optimized O(N*M) version of the previous
//   implementation, where the second for loop skips over some elements.
// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
//   single while loop, coiterating over both vectors.
// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
//   consists of a single while loop and has no branches within the loop.
//
// Output of llvm-mca:
// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841

#contraction_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"]
}

// Sparse vector dot product of two vectors.
func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
                      %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
  // Compute the intersection of the two index vectors.
  %k0, %k1 = x86vector.avx512.vp2intersect %v_A, %v_C : vector<8xi64>

  // Filter out values without a match and compress the vectors.
  %p0 = x86vector.avx512.mask.compress %k0, %v_B : vector<8xf64>
  %p1 = x86vector.avx512.mask.compress %k1, %v_D : vector<8xf64>

  // Dense vector dot product.
  %acc = arith.constant 0.0 : f64
  %r = vector.contract #contraction_trait %p0, %p1, %acc
      : vector<8xf64>, vector<8xf64> into f64

  return %r : f64
}

// Fill the data memrefs with zeros and the index memrefs with the padding
// value, so that they can be used with arbitrary input sizes of up to 128
// elements per sparse vector.
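// Padding with the largest i64 value (2^63 - 1) guarantees that padded index
// lanes never intersect with any real index, so padded tails contribute 0.0
// to the dot product.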
func.func @init_input(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                      %m_C : memref<?xi64>, %m_D : memref<?xf64>) {
  %c0 = arith.constant 0 : index
  %v_data = arith.constant dense<0.0> : vector<128xf64>
  %v_index = arith.constant dense<9223372036854775807> : vector<128xi64>

  vector.transfer_write %v_index, %m_A[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_B[%c0] : vector<128xf64>, memref<?xf64>
  vector.transfer_write %v_index, %m_C[%c0] : vector<128xi64>, memref<?xi64>
  vector.transfer_write %v_data, %m_D[%c0] : vector<128xf64>, memref<?xf64>

  return
}

func.func @fill_input_1(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                        %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  func.call @init_input(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = arith.constant 0 : index

  %v_A = arith.constant dense<[0, 1, 10, 12, 13, 17, 18, 21,
                               51, 52, 57, 61, 62, 82, 98, 99]> : vector<16xi64>
  %v_B = arith.constant dense<[1., 5., 8., 3., 2., 1., 0., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = arith.constant dense<[1, 2, 5, 10, 11, 12, 47, 48,
                               67, 68, 69, 70, 71, 72, 77, 78,
                               79, 82, 83, 84, 85, 90, 91, 98]> : vector<24xi64>
  %v_D = arith.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.,
                               2., 9., 8., 7., 2., 0., 0., 4.]> : vector<24xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<24xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<24xf64>, memref<?xf64>

  %M = arith.constant 16 : index
  %N = arith.constant 24 : index

  return %M, %N : index, index
}

func.func @fill_input_2(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                        %m_C : memref<?xi64>, %m_D : memref<?xf64>)
    -> (index, index) {
  func.call @init_input(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>) -> ()

  %c0 = arith.constant 0 : index

  %v_A = arith.constant dense<[0, 1, 3, 5, 6, 7, 8, 9,
                               51, 52, 57, 61, 62, 63, 65, 66]> : vector<16xi64>
  %v_B = arith.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64>
  %v_C = arith.constant dense<[6, 7, 11, 12, 15, 17, 19, 21,
                               30, 31, 33, 34, 37, 39, 40, 41,
                               42, 44, 45, 46, 47, 48, 49, 50,
                               62, 63, 64, 65, 66, 67, 68, 69,
                               70, 77, 78, 79, 81, 82, 89, 99]> : vector<40xi64>
  %v_D = arith.constant dense<[1., 5., 8., 3., 2., 1., 2., 9.,
                               6., 7., 7., 3., 5., 2., 9., 1.,
                               2., 9., 8., 7., 2., 1., 2., 4.,
                               4., 5., 8., 8., 2., 3., 5., 1.,
                               8., 6., 6., 4., 3., 8., 9., 2.]> : vector<40xf64>

  vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref<?xi64>
  vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref<?xf64>
  vector.transfer_write %v_C, %m_C[%c0] : vector<40xi64>, memref<?xi64>
  vector.transfer_write %v_D, %m_D[%c0] : vector<40xf64>, memref<?xf64>

  %M = arith.constant 16 : index
  %N = arith.constant 40 : index

  return %M, %N : index, index
}

// Simple vector dot product implementation: Intersect every segment of size 8
// in (%m_A, %m_B) with every segment of size 8 in (%m_C, %m_D).
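//
// Implemented as follows (a sketch of the nested loops below; segment loads
// written as array slices):
//
// r = 0.0
// for (a = 0; a < M; a += 8)
//   for (b = 0; b < N; b += 8)
//     r += vector_dot(A[a:a+8], B[a:a+8], C[b:b+8], D[b:b+8])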
func.func @memref_dot_simple(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                             %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                             %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.

  %r0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero) -> (f64) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_B = vector.transfer_read %m_B[%a], %data_zero
        : memref<?xf64>, vector<8xf64>

    %r1 = scf.for %b = %c0 to %N step %c8
        iter_args(%sum1 = %sum0) -> (f64) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %v_D = vector.transfer_read %m_D[%b], %data_zero
          : memref<?xf64>, vector<8xf64>

      %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
          : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) -> f64
      %r2 = arith.addf %sum1, %subresult : f64
      scf.yield %r2 : f64
    }

    scf.yield %r1 : f64
  }

  return %r0 : f64
}

// Optimized vector dot product implementation: Taking advantage of the fact
// that the indices in %m_A and %m_C are sorted in ascending order, skip over
// segments in (%m_C, %m_D) that are known to have no intersection with the
// current segment from (%m_A, %m_B).
func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                                %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                                %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %i0 = arith.constant 0 : i32
  %i7 = arith.constant 7 : i32
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  // Notation: %sum is the current (partial) aggregated dot product sum.
  // %b_start is the position at which the inner for loop starts iterating.
  // This value keeps increasing as earlier segments of (%m_C, %m_D) are found
  // to be no longer needed.

  %r0, %t0 = scf.for %a = %c0 to %M step %c8
      iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
    %v_A = vector.transfer_read %m_A[%a], %index_padding
        : memref<?xi64>, vector<8xi64>
    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>

    %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
        iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
      %v_C = vector.transfer_read %m_C[%b], %index_padding
          : memref<?xi64>, vector<8xi64>
      %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
      %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64

      %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
        // %v_C segment is done, no need to examine this one again (ever).
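        // Since the indices are sorted in ascending order, every later
        // segment of (%m_A, %m_B) has an even larger minimum index, so the
        // inner loop can permanently start past this segment.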
        %next_b_start2 = arith.addi %b_start1, %c8 : index
        scf.yield %sum1, %next_b_start2 : f64, index
      } else {
        %v_B = vector.transfer_read %m_B[%a], %data_zero
            : memref<?xf64>, vector<8xf64>
        %v_D = vector.transfer_read %m_D[%b], %data_zero
            : memref<?xf64>, vector<8xf64>

        %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
            -> f64
        %r3 = arith.addf %sum1, %subresult : f64
        scf.yield %r3, %b_start1 : f64, index
      }

      scf.yield %r2, %next_b_start1 : f64, index
    }

    scf.yield %r1, %next_b_start0 : f64, index
  }

  return %r0 : f64
}

// Vector dot product with a while loop. Implemented as follows:
//
// r = 0.0, a = 0, b = 0
// while (a < M && b < N) {
//   segA = A[a:a+8], segB = C[b:b+8]
//   if (segB[7] < segA[0]) b += 8
//   elif (segA[7] < segB[0]) a += 8
//   else {
//     r += vector_dot(...)
//     if (segA[7] < segB[7]) a += 8
//     elif (segB[7] < segA[7]) b += 8
//     else a += 8, b += 8
//   }
// }
func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                            %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                            %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %i0 = arith.constant 0 : i32
  %i7 = arith.constant 7 : i32
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
      : (f64, index, index) -> (f64, index, index) {
    %cond_i = arith.cmpi "slt", %a1, %M : index
    %cond_j = arith.cmpi "slt", %b1, %N : index
    %cond = arith.andi %cond_i, %cond_j : i1
    scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
  } do {
  ^bb0(%r1 : f64, %a1 : index, %b1 : index):
    // v_A, v_B, seg*_* could be part of the loop state to avoid a few
    // redundant reads.
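    // Read the current index segment of each vector and compare their bounds
    // to decide whether one side can be skipped without computing a dot
    // product.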
    %v_A = vector.transfer_read %m_A[%a1], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_C = vector.transfer_read %m_C[%b1], %index_padding
        : memref<?xi64>, vector<8xi64>

    %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
    %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
    %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
    %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>

    %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
    %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
      %b3 = arith.addi %b1, %c8 : index
      scf.yield %r1, %a1, %b3 : f64, index, index
    } else {
      %seg0_done = arith.cmpi "slt", %segA_max, %segB_min : i64
      %r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) {
        %a5 = arith.addi %a1, %c8 : index
        scf.yield %r1, %a5, %b1 : f64, index, index
      } else {
        %v_B = vector.transfer_read %m_B[%a1], %data_zero
            : memref<?xf64>, vector<8xf64>
        %v_D = vector.transfer_read %m_D[%b1], %data_zero
            : memref<?xf64>, vector<8xf64>

        %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
            -> f64
        %r6 = arith.addf %r1, %subresult : f64

        %incr_a = arith.cmpi "slt", %segA_max, %segB_max : i64
        %a6, %b6 = scf.if %incr_a -> (index, index) {
          %a7 = arith.addi %a1, %c8 : index
          scf.yield %a7, %b1 : index, index
        } else {
          %incr_b = arith.cmpi "slt", %segB_max, %segA_max : i64
          %a8, %b8 = scf.if %incr_b -> (index, index) {
            %b9 = arith.addi %b1, %c8 : index
            scf.yield %a1, %b9 : index, index
          } else {
            %a10 = arith.addi %a1, %c8 : index
            %b10 = arith.addi %b1, %c8 : index
            scf.yield %a10, %b10 : index, index
          }
          scf.yield %a8, %b8 : index, index
        }
        scf.yield %r6, %a6, %b6 : f64, index, index
      }
      scf.yield %r4, %a4, %b4 : f64, index, index
    }
    scf.yield %r2, %a2, %b2 : f64, index, index
  }

  return %r0 : f64
}

// Vector dot product with a while loop that has no branches (apart from the
// while loop itself). Implemented as follows:
//
// r = 0.0, a = 0, b = 0
// while (a < M && b < N) {
//   segA = A[a:a+8], segB = C[b:b+8]
//   r += vector_dot(...)
//   a += (segA[7] <= segB[7]) * 8
//   b += (segB[7] <= segA[7]) * 8
// }
func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
                                       %m_C : memref<?xi64>, %m_D : memref<?xf64>,
                                       %M : index, %N : index)
    -> f64 {
  // Helper constants for loops.
  %c0 = arith.constant 0 : index
  %i7 = arith.constant 7 : i32
  %c8 = arith.constant 8 : index

  %data_zero = arith.constant 0.0 : f64
  %index_padding = arith.constant 9223372036854775807 : i64

  %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
      : (f64, index, index) -> (f64, index, index) {
    %cond_i = arith.cmpi "slt", %a1, %M : index
    %cond_j = arith.cmpi "slt", %b1, %N : index
    %cond = arith.andi %cond_i, %cond_j : i1
    scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
  } do {
  ^bb0(%r1 : f64, %a1 : index, %b1 : index):
    // v_A, v_B, seg*_* could be part of the loop state to avoid a few
    // redundant reads.
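    // Unconditionally read both segments; if they share no index, the
    // compressed vectors inside @vector_dot are all zeros, so the segment
    // pair contributes 0.0 to the sum.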
    %v_A = vector.transfer_read %m_A[%a1], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_B = vector.transfer_read %m_B[%a1], %data_zero
        : memref<?xf64>, vector<8xf64>
    %v_C = vector.transfer_read %m_C[%b1], %index_padding
        : memref<?xi64>, vector<8xi64>
    %v_D = vector.transfer_read %m_D[%b1], %data_zero
        : memref<?xf64>, vector<8xf64>

    %subresult = func.call @vector_dot(%v_A, %v_B, %v_C, %v_D)
        : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
        -> f64
    %r2 = arith.addf %r1, %subresult : f64

    %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
    %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>

    // Advance %a1 by 8 iff segA's largest index is <= segB's largest index
    // (and vice versa for %b1), computed without branches.
    %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64
    %cond_a_i64 = arith.extui %cond_a : i1 to i64
    %cond_a_idx = arith.index_cast %cond_a_i64 : i64 to index
    %incr_a = arith.muli %cond_a_idx, %c8 : index
    %a2 = arith.addi %a1, %incr_a : index

    %cond_b = arith.cmpi "sle", %segB_max, %segA_max : i64
    %cond_b_i64 = arith.extui %cond_b : i1 to i64
    %cond_b_idx = arith.index_cast %cond_b_i64 : i64 to index
    %incr_b = arith.muli %cond_b_idx, %c8 : index
    %b2 = arith.addi %b1, %incr_b : index

    scf.yield %r2, %a2, %b2 : f64, index, index
  }

  return %r0 : f64
}

func.func @entry() -> i32 {
  // Initialize large buffers that can be used for multiple test cases of
  // different sizes.
  %b_A = memref.alloc() : memref<128xi64>
  %b_B = memref.alloc() : memref<128xf64>
  %b_C = memref.alloc() : memref<128xi64>
  %b_D = memref.alloc() : memref<128xf64>

  %m_A = memref.cast %b_A : memref<128xi64> to memref<?xi64>
  %m_B = memref.cast %b_B : memref<128xf64> to memref<?xf64>
  %m_C = memref.cast %b_C : memref<128xi64> to memref<?xi64>
  %m_D = memref.cast %b_D : memref<128xf64> to memref<?xf64>

  // --- Test case 1 ---.
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
  %M1, %N1 = func.call @fill_input_1(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
      -> (index, index)

  %r0 = func.call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r0 : f64
  // CHECK: 86

  %r1 = func.call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r1 : f64
  // CHECK: 86

  %r2 = func.call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r2 : f64
  // CHECK: 86

  %r6 = func.call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r6 : f64
  // CHECK: 86

  // --- Test case 2 ---.
  // M and N must be a multiple of 8 if smaller than 128.
  // (Because padding kicks in only for out-of-bounds accesses.)
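  // Test case 2 uses different contents and a longer (%m_C, %m_D) input
  // (N = 40 instead of 24), exercising additional loop iterations.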
  %M2, %N2 = func.call @fill_input_2(%m_A, %m_B, %m_C, %m_D)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>)
      -> (index, index)

  %r3 = func.call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r3 : f64
  // CHECK: 111

  %r4 = func.call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r4 : f64
  // CHECK: 111

  %r5 = func.call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r5 : f64
  // CHECK: 111

  %r7 = func.call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
         index, index) -> f64
  vector.print %r7 : f64
  // CHECK: 111

  // Release all resources.
  memref.dealloc %b_A : memref<128xi64>
  memref.dealloc %b_B : memref<128xf64>
  memref.dealloc %b_C : memref<128xi64>
  memref.dealloc %b_D : memref<128xf64>

  %r = arith.constant 0 : i32
  return %r : i32
}