1# This test generates all variants of wmma intrinsics and verifies that LLVM 2# generates correct instructions for them. This is the test generator only. The 3# test scripts themselves are in wmma-ptx*-sm*.py files. 4 5# RUN: true 6 7from __future__ import print_function 8 9import argparse 10from itertools import product 11from string import Template 12 13class MMAType: 14 def __init__(self, ptx_type): 15 self.ptx_type = ptx_type 16 self.llvm_type = { 17 "f16": "<2 x half>", 18 "f32": "float", 19 "f64": "double", 20 "s32": "i32", 21 "b16": "i32", 22 "s8": "i32", 23 "u8": "i32", 24 "s4": "i32", 25 "u4": "i32", 26 "b1": "i32", 27 "bf16": "i32", 28 "tf32": "i32", 29 }[ptx_type] 30 31 self.ptx_reg_pattern = { 32 "f16": "%r[0-9]+", 33 "f32": "%f[0-9]+", 34 "f64": "%fd[0-9]+", 35 }.get(ptx_type, "%r[0-9]+") 36 37 def __repr__(self): 38 return "%s/%s" % (self.ptx_type, self.llvm_type) 39 40 41class MMAFrag: 42 def __init__(self, geom, frag, ptx_elt_type): 43 self.geom = geom 44 self.frag = frag 45 self.mma_type = MMAType(ptx_elt_type) 46 self.nregs = { 47 # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16 48 "m16n16k16:a:u8": 2, 49 "m16n16k16:a:s8": 2, 50 "m16n16k16:b:u8": 2, 51 "m16n16k16:b:s8": 2, 52 "m16n16k16:c:s32": 8, 53 "m16n16k16:d:s32": 8, 54 "m8n32k16:a:u8": 1, 55 "m8n32k16:a:s8": 1, 56 "m8n32k16:b:u8": 4, 57 "m8n32k16:b:s8": 4, 58 "m8n32k16:c:s32": 8, 59 "m8n32k16:d:s32": 8, 60 "m32n8k16:a:u8": 4, 61 "m32n8k16:a:s8": 4, 62 "m32n8k16:b:u8": 1, 63 "m32n8k16:b:s8": 1, 64 "m32n8k16:c:s32": 8, 65 "m32n8k16:d:s32": 8, 66 "m8n8k16:a:u8": 1, 67 "m8n8k16:a:s8": 1, 68 "m8n8k16:b:u8": 1, 69 "m8n8k16:b:s8": 1, 70 "m8n8k16:c:s32": 2, 71 "m8n8k16:d:s32": 2, 72 "m16n8k16:a:u8": 2, 73 "m16n8k16:a:s8": 2, 74 "m16n8k16:b:u8": 1, 75 "m16n8k16:b:s8": 1, 76 "m16n8k16:c:s32": 4, 77 "m16n8k16:d:s32": 4, 78 "m16n8k32:a:u8": 4, 79 "m16n8k32:a:s8": 4, 80 "m16n8k32:b:u8": 2, 81 "m16n8k32:b:s8": 2, 82 "m16n8k32:c:s32": 4, 83 "m16n8k32:d:s32": 4, 84 # u4/s4 -> s32 @ m8n8k32 (u4/s4) 85 "m8n8k32:a:u4": 1, 86 "m8n8k32:a:s4": 1, 87 "m8n8k32:b:u4": 1, 88 "m8n8k32:b:s4": 1, 89 "m8n8k32:c:s32": 2, 90 "m8n8k32:d:s32": 2, 91 "m16n8k32:a:u4": 2, 92 "m16n8k32:a:s4": 2, 93 "m16n8k32:b:u4": 1, 94 "m16n8k32:b:s4": 1, 95 "m16n8k32:c:s32": 4, 96 "m16n8k32:d:s32": 4, 97 "m16n8k64:a:u4": 4, 98 "m16n8k64:a:s4": 4, 99 "m16n8k64:b:u4": 2, 100 "m16n8k64:b:s4": 2, 101 "m16n8k64:c:s32": 4, 102 "m16n8k64:d:s32": 4, 103 # b1 -> s32 @ m8n8k128(b1) 104 "m8n8k128:a:b1": 1, 105 "m8n8k128:b:b1": 1, 106 "m8n8k128:c:s32": 2, 107 "m8n8k128:d:s32": 2, 108 "m16n8k128:a:b1": 2, 109 "m16n8k128:b:b1": 1, 110 "m16n8k128:c:s32": 4, 111 "m16n8k128:d:s32": 4, 112 "m16n8k256:a:b1": 4, 113 "m16n8k256:b:b1": 2, 114 "m16n8k256:c:s32": 4, 115 "m16n8k256:d:s32": 4, 116 # bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16 117 "m16n16k16:a:bf16": 4, 118 "m16n16k16:b:bf16": 4, 119 "m8n32k16:a:bf16": 2, 120 "m8n32k16:b:bf16": 8, 121 "m32n8k16:a:bf16": 8, 122 "m32n8k16:b:bf16": 2, 123 "m16n8k16:a:bf16": 4, 124 "m16n8k16:b:bf16": 2, 125 "m16n8k16:c:f32": 4, 126 "m16n8k16:d:f32": 4, 127 "m16n8k8:a:bf16": 2, 128 "m16n8k8:b:bf16": 1, 129 "m16n8k8:c:f32": 4, 130 "m16n8k8:d:f32": 4, 131 "m8n8k4:a:f64": 1, 132 "m8n8k4:b:f64": 1, 133 "m8n8k4:c:f64": 2, 134 "m8n8k4:d:f64": 2, 135 # tf32 -> s32 @ m16n16k8 136 "m16n16k8:a:tf32": 4, 137 "m16n16k8:b:tf32": 4, 138 "m16n8k4:a:tf32": 2, 139 "m16n8k4:b:tf32": 1, 140 "m16n8k4:c:f32": 4, 141 "m16n8k4:d:f32": 4, 142 "m16n8k8:a:tf32": 4, 143 "m16n8k8:b:tf32": 2, 144 "m16n8k8:c:f32": 4, 145 "m16n8k8:d:f32": 4, 146 "m8n8k4:a:f16": 2, 147 "m8n8k4:b:f16": 2, 148 "m16n8k8:a:f16": 2, 149 "m16n8k8:b:f16": 1, 150 "m16n8k8:c:f16": 2, 151 "m16n8k8:d:f16": 2, 152 "m16n8k8:c:f32": 4, 153 "m16n8k8:d:f32": 4, 154 "m16n8k16:a:f16": 4, 155 "m16n8k16:b:f16": 2, 156 "m16n8k16:c:f16": 2, 157 "m16n8k16:d:f16": 2, 158 "m16n8k16:c:f32": 4, 159 "m16n8k16:d:f32": 4, 160 # ldmatrix 161 "m8n8:x1:b16": 1, 162 "m8n8:x2:b16": 2, 163 "m8n8:x4:b16": 4, 164 }.get( 165 "%s:%s:%s" % (geom, frag, ptx_elt_type), 166 { 167 # All other FP shape/fragment/type combinations have the same size 168 "a:f16": 8, 169 "b:f16": 8, 170 "c:f16": 4, 171 "d:f16": 4, 172 "c:f32": 8, 173 "d:f32": 8, 174 }.get("%s:%s" % (frag, ptx_elt_type), None), 175 ) 176 assert self.nregs 177 178 def __repr__(self): 179 return "%s:%s:%s%s" % ( 180 self.geom, 181 self.frag, 182 self.mma_type, 183 "" if self.nregs == 1 else ("*%d" % self.nregs), 184 ) 185 186 187class MMAOp: 188 def __init__(self, a, b, c, d): 189 self.a = a 190 self.b = b 191 self.c = c 192 self.d = d 193 194 def __repr__(self): 195 return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d) 196 197 198def make_mma_ops(geoms, types_a, types_b, types_c, types_d): 199 ops = [] 200 for geom, type_a, type_c in product(geoms, types_a, types_c): 201 for type_b, type_d in product( 202 types_b if types_b else [type_a], types_d if types_d else [type_c] 203 ): 204 ops.append( 205 MMAOp( 206 MMAFrag(geom, "a", type_a), 207 MMAFrag(geom, "b", type_b), 208 MMAFrag(geom, "c", type_c), 209 MMAFrag(geom, "d", type_d), 210 ) 211 ) 212 return ops 213 214 215def make_ldst_ops(geoms, frags, types): 216 return [ 217 MMAFrag(geom, frag, ptx_type) 218 for (geom, frag, ptx_type) in product(geoms, frags, types) 219 ] 220 221 222def make_ldmatrix_ops(geoms, frags, types): 223 return [ 224 MMAFrag(geom, frag, ptx_type) 225 for (geom, frag, ptx_type) in product(geoms, frags, types) 226 ] 227 228 229def get_wmma_ops(): 230 return ( 231 make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) 232 + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], []) 233 + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], []) 234 + make_mma_ops( 235 ["m16n16k16", "m32n8k16", "m8n32k16"], 236 ["f16"], 237 [], 238 ["f16", "f32"], 239 ["f16", "f32"], 240 ) 241 + make_mma_ops( 242 ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], [] 243 ) 244 + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], []) 245 + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], []) 246 ) 247 248 249def get_mma_ops(): 250 return ( 251 make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], []) 252 + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], []) 253 + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], []) 254 + make_mma_ops( 255 ["m8n8k4", "m16n8k8", "m16n8k16"], 256 ["f16"], 257 [], 258 ["f16", "f32"], 259 ["f16", "f32"], 260 ) 261 + make_mma_ops( 262 ["m8n8k16", "m16n8k16", "m16n8k32"], ["s8", "u8"], ["s8", "u8"], ["s32"], [] 263 ) 264 + make_mma_ops( 265 ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], [] 266 ) 267 + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], []) 268 ) 269 270 271def get_ldst_ops(kind): 272 ldst_ops = ( 273 make_ldst_ops( 274 ["m16n16k16", "m32n8k16", "m8n32k16"], 275 ["a", "b"], 276 ["f16", "u8", "s8", "bf16"], 277 ) 278 + make_ldst_ops( 279 ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"] 280 ) 281 + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"]) 282 + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) 283 + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) 284 + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) 285 + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) 286 + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]) 287 ) 288 return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")] 289 290 291def get_ldmatrix_ops(): 292 return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) 293 294 295def is_wmma_geom_supported(geom): 296 # geometries for FP and ints. 297 if geom in ["m8n32k16", "m32n8k16"]: 298 return ptx_version >= 61 299 # geometries for sub-ints. 300 if geom in ["m8n8k32", "m8n8k128"]: 301 return ptx_version >= 63 and gpu_arch >= 75 302 if geom == "m16n16k16": 303 return ptx_version >= 60 304 if geom == "m16n8k8": 305 return ptx_version >= 65 306 if geom in ["m16n16k8", "m8n8k4"]: 307 return ptx_version >= 70 308 assert False # Unexpected geometry. 309 310 311def is_mma_geom_supported(geom): 312 # geometries for FP and ints. 313 if geom == "m8n8k4": 314 return ptx_version >= 64 315 if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]: 316 return ptx_version >= 65 317 if geom in [ 318 "m16n8k16", 319 "m16n8k4", 320 "m16n8k32", 321 "m16n8k64", 322 "m8n8k128", 323 "m16n8k128", 324 "m16n8k256", 325 ]: 326 return ptx_version >= 70 327 assert False # Unexpected geometry. 328 329 330def is_ldmatrix_geom_supported(geom): 331 if geom in ["m8n8"]: 332 return ptx_version >= 65 and gpu_arch >= 75 333 assert False # Unexpected geometry. 334 335 336def is_type_supported(ptx_type): 337 if ptx_type in ["s8", "u8", "s32"]: 338 return ptx_version >= 63 and gpu_arch >= 72 339 if ptx_type in ["s4", "u4", "b1"]: 340 return ptx_version >= 63 and gpu_arch >= 75 341 if ptx_type == "b16": 342 return ptx_version >= 65 and gpu_arch >= 75 343 if ptx_type in ["bf16", "tf32", "f64"]: 344 return ptx_version >= 70 345 return ptx_version >= 60 and gpu_arch >= 70 346 347 348def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf): 349 if not ( 350 is_type_supported(op.a.mma_type.ptx_type) and is_wmma_geom_supported(op.a.geom) 351 ): 352 return False 353 354 # rnd is only supported for FP64 WMMA 355 if rnd and op.a.mma_type.ptx_type != "f64": 356 return False 357 358 if satf: 359 # satfinite for floating points was removed in PTX 6.5 360 if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65: 361 return False 362 if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]: 363 return False 364 365 # sub-integer require row/col layout. 366 if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]: 367 return layout_a == "row" and layout_b == "col" 368 return True 369 370 371def is_mma_variant_supported(op, layout_a, layout_b, satf): 372 if not ( 373 is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom) 374 ): 375 return False 376 377 if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]: 378 return False 379 380 # If the type of C is f32 then so must the type of D 381 if ( 382 op.a.geom == "m8n8k4" 383 and op.c.mma_type.ptx_type == "f32" 384 and op.d.mma_type.ptx_type != "f32" 385 ): 386 return False 387 388 # A and B type must be the same. C and D type must be the same 389 if op.a.geom == "m16n8k8" and ( 390 op.a.mma_type.ptx_type != op.b.mma_type.ptx_type 391 or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type 392 ): 393 return False 394 395 # C and D type must be the same 396 if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type: 397 return False 398 399 # Require row/col layout for all MMA except m8n8k4 on FP16 400 if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"): 401 return layout_a == "row" and layout_b == "col" 402 return True 403 404 405def is_ldst_variant_supported(frag, layout): 406 if not ( 407 is_type_supported(frag.mma_type.ptx_type) and is_wmma_geom_supported(frag.geom) 408 ): 409 return False 410 if frag.mma_type.ptx_type in ["s4", "u4", "b1"]: 411 # sub-integer require sm_75 and ptx63, row/col layout for a/b. 412 return ( 413 (frag.frag == "a" and layout == "row") 414 or (frag.frag == "b" and layout == "col") 415 or frag.frag in ["c", "d"] 416 ) 417 return True 418 419 420def is_ldmatrix_variant_supported(frag): 421 if not ( 422 is_type_supported(frag.mma_type.ptx_type) 423 and is_ldmatrix_geom_supported(frag.geom) 424 ): 425 return False 426 return frag.frag in ["x1", "x2", "x4"] 427 428 429def make_wmma_slice_ty(frag): 430 return [frag.mma_type.llvm_type] * frag.nregs 431 432 433def make_wmma_ld_ret_ty(frag): 434 results = make_wmma_slice_ty(frag) 435 if len(results) == 1: 436 return "%s" % results[0] 437 return "{%s}" % ", ".join(results) 438 439 440# returns address space 441def get_aspace(space): 442 space_map = { 443 ".global": 1, 444 ".shared": 3, 445 ".const": 4, 446 ".local": 5, 447 ".param": 101, 448 "": 0, 449 ".generic": 0, 450 } 451 return space_map[space] 452 453 454def get_pspace(space): 455 return "p%di8" % get_aspace(space) 456 457 458def check_pattern(frag): 459 return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs) 460 461 462def gen_wmma_load_tests(): 463 load_template = """ 464declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args}); 465 466; CHECK-LABEL: .func {{.*}}test_${function}( 467define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) { 468; CHECK: ${instruction} 469; CHECK: {${check_result}} 470; CHECK: [%rd{{[0-9]+}}]${stride_pattern} 471 %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args}); 472 ret ${ret_ty} %v0; 473} 474 475; CHECK-LABEL: .func{{.*}}test_${function}_o( 476define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) { 477; CHECK: ${instruction} 478; CHECK: {${check_result}} 479; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern} 480 %src1 = getelementptr i8, i8 ${as}* %src, i32 128; 481 %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args}); 482 ret ${ret_ty} %v0; 483} 484""" 485 intrinsic_template = ( 486 "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}" 487 ) 488 instruction_template = ( 489 "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}" 490 ) 491 492 generated_items = [] 493 494 for frag, layout, space, stride in product( 495 get_ldst_ops("load"), 496 ["row", "col"], 497 ["", ".shared", ".global"], 498 ["", ".stride"], 499 ): 500 if not is_ldst_variant_supported(frag, layout): 501 continue 502 503 params = { 504 "abc": frag.frag, 505 "aligned": ".aligned" if ptx_version >= 63 else "", 506 "layout": layout, 507 "space": space, 508 "stride": stride, 509 "itype": frag.mma_type.ptx_type, 510 "pspace": get_pspace(space), 511 "as": "addrspace(%d)" % get_aspace(space), 512 "geom": frag.geom, 513 } 514 515 test_params = params 516 test_params["intrinsic"] = Template(intrinsic_template).substitute(params) 517 test_params["function"] = test_params["intrinsic"].replace(".", "_") 518 test_params["instruction"] = Template(instruction_template).substitute(params) 519 test_params["ret_ty"] = make_wmma_ld_ret_ty(frag) 520 test_params["check_result"] = check_pattern(frag) 521 522 if stride: 523 test_params["extra_args"] = ", i32 %stride" 524 test_params["stride_pattern"] = ", %r{{[0-9]+}}" 525 else: 526 test_params["extra_args"] = "" 527 test_params["stride_pattern"] = "" 528 529 print(Template(load_template).substitute(test_params)) 530 531 generated_items.append((test_params["intrinsic"], test_params["instruction"])) 532 533 return generated_items 534 535 536def make_wmma_slice_args(frag): 537 return ", ".join( 538 [ 539 "%s %%%s%d" % (t, frag.frag, i) 540 for i, t in enumerate(make_wmma_slice_ty(frag)) 541 ] 542 ) 543 544 545def gen_wmma_store_tests(): 546 store_template = """ 547declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args}); 548 549; CHECK-LABEL: .func {{.*}}test_${function}( 550define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) { 551; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}} 552; CHECK: {${check_args}} 553; CHECK: ${stride_pattern} 554 call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args}); 555 ret void 556} 557 558; CHECK-LABEL: .func{{.*}}test_${function}_o( 559define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) { 560; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128] 561; CHECK: ${check_args} 562; CHECK: ${stride_pattern} 563 %src1 = getelementptr i8, i8 ${as}* %src, i32 128; 564 call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args}); 565 ret void 566} 567""" 568 intrinsic_template = ( 569 "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}" 570 ) 571 instruction_template = ( 572 "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}" 573 ) 574 575 generated_items = [] 576 577 for frag, layout, space, stride in product( 578 get_ldst_ops("store"), 579 ["row", "col"], 580 ["", ".shared", ".global"], 581 ["", ".stride"], 582 ): 583 584 if not is_ldst_variant_supported(frag, layout): 585 continue 586 587 params = { 588 "abc": frag.frag, 589 "aligned": ".aligned" if ptx_version >= 63 else "", 590 "layout": layout, 591 "space": space, 592 "stride": stride, 593 "itype": frag.mma_type.ptx_type, 594 "pspace": get_pspace(space), 595 "as": "addrspace(%d)" % get_aspace(space), 596 "geom": frag.geom, 597 } 598 599 test_params = params 600 test_params["intrinsic"] = Template(intrinsic_template).substitute(params) 601 test_params["function"] = test_params["intrinsic"].replace(".", "_") 602 test_params["instruction"] = Template(instruction_template).substitute(params) 603 test_params["ret_ty"] = make_wmma_ld_ret_ty(frag) 604 test_params["check_args"] = check_pattern(frag) 605 if stride: 606 test_params["extra_args"] = ", i32 %stride" 607 test_params["stride_pattern"] = ", %r{{[0-9]+}};" 608 else: 609 test_params["extra_args"] = "" 610 test_params["stride_pattern"] = ";" 611 test_params["args"] = make_wmma_slice_args(frag) 612 613 print(Template(store_template).substitute(test_params)) 614 generated_items.append((test_params["intrinsic"], test_params["instruction"])) 615 616 return generated_items 617 618 619def gen_ldmatrix_tests(): 620 ldmatrix_template = """ 621declare ${ret_ty} @${intrinsic}(i8 ${as}* %src); 622 623; CHECK-LABEL: .func {{.*}}test_${function}( 624define ${ret_ty} @test_${function}(i8 ${as}* %src) { 625; CHECK: ${instruction} 626; CHECK: {${check_result}} 627; CHECK: [%rd{{[0-9]+}}] 628 %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src); 629 ret ${ret_ty} %v0; 630} 631 632; CHECK-LABEL: .func{{.*}}test_${function}_o( 633define ${ret_ty} @test_${function}_o(i8 ${as}* %src) { 634; CHECK: ${instruction} 635; CHECK: {${check_result}} 636; CHECK: [%rd{{[0-9]+}}+128] 637 %src1 = getelementptr i8, i8 ${as}* %src, i32 128; 638 %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1); 639 ret ${ret_ty} %v0; 640} 641""" 642 intrinsic_template = ( 643 "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" 644 ) 645 instruction_template = ( 646 "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" 647 ) 648 649 generated_items = [] 650 651 for frag, space, trans in product( 652 get_ldmatrix_ops(), 653 ["", ".shared"], 654 ["", ".trans"], 655 ): 656 if not is_ldmatrix_variant_supported(frag): 657 continue 658 659 params = { 660 "frag": frag.frag, 661 "space": space, 662 "trans": trans, 663 "itype": frag.mma_type.ptx_type, 664 "pspace": get_pspace(space), 665 "as": "addrspace(%d)" % get_aspace(space), 666 "geom": frag.geom, 667 } 668 669 test_params = params 670 test_params["intrinsic"] = Template(intrinsic_template).substitute(params) 671 test_params["function"] = test_params["intrinsic"].replace(".", "_") 672 test_params["instruction"] = Template(instruction_template).substitute(params) 673 test_params["ret_ty"] = make_wmma_ld_ret_ty(frag) 674 test_params["check_result"] = check_pattern(frag) 675 676 print(Template(ldmatrix_template).substitute(test_params)) 677 678 generated_items.append((test_params["intrinsic"], test_params["instruction"])) 679 680 return generated_items 681 682 683def mma_signature(op): 684 if op.a.mma_type.ptx_type == "f16": 685 # FP16 ops identified by accumulator & result type. 686 return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) 687 elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type: 688 # other ops are identified by input types. 689 return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type) 690 else: 691 # if input types are the same, it only appears once. 692 return op.a.mma_type.ptx_type 693 694 695def mma_ptx_signature(op): 696 # Encode all four types as D.A.B.C 697 return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c)) 698 699 700def wmma_signature(op): 701 if op.a.mma_type.ptx_type == "f16": 702 # FP16 ops identified by accumulator & result type. 703 return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) 704 else: 705 # other ops are identified by input type. 706 return op.a.mma_type.ptx_type 707 708 709def wmma_ptx_signature(op): 710 if op.a.mma_type.ptx_type == "f16": 711 # FP16 instructions use D.C 712 return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type) 713 else: 714 # other instructions encode all four types as D.A.B.C 715 return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c)) 716 717 718def common_mma_test_gen(params, op, intrinsic_template, instruction_template): 719 mma_template = """ 720declare ${ret_ty} @${intrinsic}( 721 ${args}); 722 723; CHECK-LABEL: .func {{.*}}test_${function}( 724define ${ret_ty} @test_${function}( 725 ${args}) { 726; CHECK: ${instruction} 727; CHECK-NEXT: ${check_d} 728; CHECK-NEXT: ${check_a} 729; CHECK-NEXT: ${check_b} 730; CHECK-NEXT: ${check_c} 731 %r = call ${ret_ty} @${intrinsic}( 732 ${args}); 733 ret ${ret_ty} %r; 734} 735""" 736 737 test_params = params 738 test_params["intrinsic"] = Template(intrinsic_template).substitute(params) 739 test_params["function"] = test_params["intrinsic"].replace(".", "_") 740 test_params["instruction"] = Template(instruction_template).substitute(params) 741 test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) 742 test_params["check_a"] = check_pattern(op.a) 743 test_params["check_b"] = check_pattern(op.b) 744 test_params["check_c"] = check_pattern(op.c) 745 test_params["check_d"] = check_pattern(op.d) 746 args = ",\n ".join(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c)) 747 test_params["args"] = args 748 print(Template(mma_template).substitute(test_params)) 749 return (test_params["intrinsic"], test_params["instruction"]) 750 751 752def get_b1_ops(ptx_type): 753 if ptx_type != "b1": 754 return [""] 755 if ptx_version >= 71: 756 return [".xor.popc", ".and.popc"] 757 return [".xor.popc"] 758 759 760def gen_wmma_mma_tests(): 761 wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}" 762 wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}" 763 764 generated_items = [] 765 766 for op, alayout, blayout, rnd, satf in product( 767 get_wmma_ops(), 768 ["row", "col"], 769 ["row", "col"], 770 [".rn", ".rz", ".rm", ".rp", ""], 771 [".satfinite", ""], 772 ): 773 774 if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf): 775 continue 776 777 for b1op in get_b1_ops(op.a.mma_type.ptx_type): 778 params = { 779 "aligned": ".aligned" if ptx_version >= 63 else "", 780 "alayout": alayout, 781 "blayout": blayout, 782 "intrinsic_signature": wmma_signature(op), 783 "ptx_signature": wmma_ptx_signature(op), 784 "satf": satf, 785 "rnd": rnd, 786 "geom": op.a.geom, 787 "b1op": b1op, 788 } 789 790 intrinsic_template = wmma_intrinsic_template 791 instruction_template = wmma_instruction_template 792 793 generated_items.append( 794 common_mma_test_gen( 795 params, op, intrinsic_template, instruction_template 796 ) 797 ) 798 799 return generated_items 800 801 802def gen_mma_tests(): 803 mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}" 804 mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}" 805 806 generated_items = [] 807 808 for op, alayout, blayout, satf in product( 809 get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""] 810 ): 811 812 if not is_mma_variant_supported(op, alayout, blayout, satf): 813 continue 814 815 for b1op in get_b1_ops(op.a.mma_type.ptx_type): 816 params = { 817 "aligned": ".aligned" if ptx_version >= 63 else "", 818 "alayout": alayout, 819 "blayout": blayout, 820 "intrinsic_signature": mma_signature(op), 821 "ptx_signature": mma_ptx_signature(op), 822 "satf": satf, 823 "geom": op.a.geom, 824 "b1op": b1op, 825 } 826 827 intrinsic_template = mma_intrinsic_template 828 instruction_template = mma_instruction_template 829 830 generated_items.append( 831 common_mma_test_gen( 832 params, op, intrinsic_template, instruction_template 833 ) 834 ) 835 836 return generated_items 837 838 839# Append complete list of intrinsics and instructions we've generated tests for. 840# Generate set of checks to verify that that we did generate sensible set of 841# tests for the given combination of PTX and SM variants. 842# 843def gen_check_unsupported_ops(items): 844 print( 845 "; Complete list of intrinsics supported by PTX%d on sm_%d" 846 % (ptx_version, gpu_arch) 847 ) 848 print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}") 849 print( 850 """ 851 852; NOEXTGEOM-NOT: {{m8n32|m32n8}} 853; NOINT-NOT: .{{s32|s8}} 854; NOSUBINT-NOT: {{s4|u4|b1}} 855; NOMMA-NOT: .m8n8k4. 856; NOALTFLOAT-NOT: .{{bf16|tf32}} 857; NODOUBLE-NOT: .f64 858; NOLDMATRIX-NOT: ldmatrix.sync.aligned 859 860; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p 861; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p 862; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32 863; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16 864; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16 865; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32 866 867; PTX60 adds support for m32n8k16/m8n32k16 geometries. 868; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p 869; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p 870; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32 871; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16 872; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16 873; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32 874 875; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p 876; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p 877; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32 878; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16 879; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16 880; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32 881 882; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p 883; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p 884; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p 885; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p 886; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p 887; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p 888; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p 889; INT-DAG: m16n16k16.mma.{{.*}}.u8 890; INT-DAG: m16n16k16.mma.{{.*}}.s8 891; INT-DAG: m8n32k16.mma.{{.*}}.u8 892; INT-DAG: m8n32k16.mma.{{.*}}.s8 893; INT-DAG: m32n8k16.mma.{{.*}}.u8 894; INT-DAG: m32n8k16.mma.{{.*}}.s8 895 896; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p 897; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p 898; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p 899; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p 900; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p 901; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4 902; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4 903; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1 904 905; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p 906; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p 907; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p 908; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p 909; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16 910; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16 911; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16 912; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32 913 914; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p 915; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p 916; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64 917 918; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32 919; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16 920; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16 921; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32 922 923; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16 924; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32 925; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8 926; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8 927; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8 928; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8 929; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4 930; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4 931; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4 932; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4 933 934; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16 935; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16 936; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16 937; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16 938; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16 939; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16 940; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16 941; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16 942; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16 943; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 944; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 945; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 946 947; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 948; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 949; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 950; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16 951; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16 952; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16 953; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32 954; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8 955; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8 956; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8 957; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8 958; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8 959; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8 960; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8 961; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8 962; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4 963; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4 964; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4 965; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4 966; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4 967; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4 968; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4 969; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4 970; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1 971; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1 972; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1 973; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1 974; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1 975; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1 976; 977 978""" 979 ) 980 981 print("; INTRINSICS_LIST_BEGIN") 982 for intrinsic, instruction in sorted(items): 983 print("; ", intrinsic, " -> ", instruction, "") 984 print("; INTRINSICS_LIST_END") 985 print("; INTRINSICS: ; INTRINSICS_LIST_END") 986 987 988def gen_tests(): 989 items = gen_wmma_load_tests() 990 items += gen_wmma_store_tests() 991 items += gen_ldmatrix_tests() 992 items += gen_wmma_mma_tests() 993 items += gen_mma_tests() 994 gen_check_unsupported_ops(items) 995 996 997def main(): 998 global ptx_version 999 global gpu_arch 1000 parser = argparse.ArgumentParser() 1001 parser.add_argument("--ptx", type=int, default=60) 1002 parser.add_argument("--gpu-arch", type=int, default=70) 1003 args = parser.parse_args() 1004 1005 ptx_version = args.ptx 1006 gpu_arch = args.gpu_arch 1007 1008 gen_tests() 1009 1010 1011if __name__ == "__main__": 1012 main() 1013