`void
'matmul_name` ('rtype` * const restrict retarray,
	'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
  /* Compute retarray = MATMUL (a, b) for the element type of this
     template instantiation.

     retarray    Result descriptor.  When its base_addr is NULL, the
                 dimensions are derived from a and b and the storage is
                 allocated here.  Otherwise it is used as passed in, and
                 its extents are compared against the operands only when
                 compile_options.bounds_check is set.
     a, b        Operand descriptors.  At least one of them must have
                 rank 2 (asserted below).
     try_blas    Nonzero permits the external gemm call.  Bit value 2
                 selects "C" for transa, bit value 4 selects "C" for
                 transb.
     blas_limit  The gemm call is only considered when
                 xcount * ycount * count exceeds blas_limit cubed.
     gemm        BLAS-style routine, must be non-NULL when the gemm path
                 is actually taken (asserted there).

     A mismatch between the inner extents of a and b, or a failed bounds
     check, is reported through runtime_error.  */

  const 'rtype_name` * restrict abase;
  const 'rtype_name` * restrict bbase;
  'rtype_name` * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      /* The result is not allocated yet: set up its bounds from the
         operands and allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      /* Caller-provided result: verify that its extents agree with the
         extents implied by the operands.  */
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }
'
sinclude(`matmul_asm_'rtype_code`.m4')dnl
`
  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count]. */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  /* The inner extents must agree.  A mismatch is tolerated only when
     neither extent is positive.  */
  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  /* External gemm path: requires a unit-stride result column, one unit
     stride in each operand, and a problem size above the blas_limit
     threshold.  The size product is formed in float, as is POW3.  */
  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const 'rtype_name` one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
        ldb = (bxstride == 1) ? bystride : bxstride;

      /* Fall through to the internal code for non-positive leading
         dimensions and for any dimension of at most 1.  */
      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1
      && GFC_DESCRIPTOR_RANK (b) != 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const 'rtype_name` *a, *b;
      'rtype_name` *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      'rtype_name` f11, f12, f21, f22, f31, f32, f41, f42,
                 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      'rtype_name` *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments */
      /* The pointers are biased so that the loops below can use
         1-based [row + column * dim1] indexing.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = ('rtype_name`)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the same allocator as for the result array above instead of
         a bare malloc, whose result was previously written through
         without any NULL check.
         NOTE(review): xmallocarray is presumed to report allocation
         failure itself rather than return NULL - confirm against its
         definition.  The buffer is still released with free below.  */
      t1 = xmallocarray (t1_dim, sizeof ('rtype_name`));

      /* Start turning the crank. */
      /* Blocking: jj walks columns of c in steps of 512, ll walks the
         inner dimension in steps of 256, ii walks rows in steps of 256.
         For each (ll, ii) block the relevant part of a is first copied
         into t1 with a row pitch of 256 (the << 8 below).  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current block of a into t1, two by two,
                     followed by the odd leftover row and column.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* From here on uisec is the multiple-of-4 part of
                     isec, used by the 4x4 and 4x1 kernels.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      /* 4x4 kernel.  */
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      /* Leftover rows: 1x4 kernel.  */
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  /* Leftover columns: 4x1 kernel, then 1x1.  */
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      /* Rows of a and columns of b are both contiguous: plain dot
         products.  */
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const 'rtype_name` *restrict abase_x;
          const 'rtype_name` *restrict bbase_y;
          'rtype_name` *restrict dest_y;
          'rtype_name` s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = ('rtype_name`) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const 'rtype_name` *restrict bbase_y;
          'rtype_name` s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = ('rtype_name`) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Rank-1 a with general strides: one dot product per column
         of b.  */
      const 'rtype_name` *restrict bbase_y;
      'rtype_name` s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = ('rtype_name`) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else if (axstride < aystride)
    {
      /* General strides with x as the innermost loop: clear the result
         first, then accumulate into it.  */
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = ('rtype_name`)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else
    {
      /* Remaining general-stride case: one dot product per result
         element.  */
      const 'rtype_name` *restrict abase_x;
      const 'rtype_name` *restrict bbase_y;
      'rtype_name` *restrict dest_y;
      'rtype_name` s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = ('rtype_name`) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max
'