# This script generates all variants of wmma builtins, verifies that clang calls
# the correct LLVM intrinsics, and checks that the availability of specific
# builtins is constrained by the correct PTX version and the target GPU variant.

# Dummy test run to avoid lit warnings.
# RUN: echo "This is not a real test. It's a generator for builtins-nvptx-mma.cu" >/dev/null

from __future__ import print_function

import argparse
from collections import defaultdict
from itertools import product
from string import Template


class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.ptx_type = ptx_elt_type

    def __repr__(self):
        return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)


class MMAOp:
    def __init__(self, a, b, c, d, b1op=""):
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.b1op = b1op

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
    ops = []
    if b1ops is None:
        b1ops = [""]
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops += [
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                    b1op,
                )
                for b1op in b1ops
            ]
    return ops


def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def get_mma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(
            ["m8n8k128"], ["b1"], [], ["s32"], [], [".xor.popc", ".and.popc"]
        )
    )


def get_ldst_ops():
    # NOTE: fragments are from the point of view of PTX.
    # Fragment `d` is only for store ops; the others are for both loads and stores.
    return (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        +
        # TF32 m16n16k8 is odd.
        # For fragment 'C' it uses __mma_*tf32*_m16n16k8_ld_c
        # but 'D' calls __mma_m16n16k8_st_c_*f32*.
        make_ldst_ops(["m16n16k8"], ["a", "b", "c"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["d"], ["f32"])
    )
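# Illustration only (not used by the generator): the helpers above simply take
# cartesian products of their arguments. For example,
#   make_ldst_ops(["m16n16k16"], ["a", "b"], ["f16"])
# yields two fragments with reprs "m16n16k16:a:f16" and "m16n16k16:b:f16", and
#   make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
# yields a single MMAOp whose a/b/c/d fragments are all m8n8k4 f64. Empty
# types_b/types_d mean "same as types_a/types_c".

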
def is_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70 and gpu_arch >= 80
    assert False  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70 and gpu_arch >= 80
    return ptx_version >= 60 and gpu_arch >= 70


def is_rnd_supported(op):
    # rnd is only supported for FP64 WMMA.
    return op.a.ptx_type == "f64"


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (is_type_supported(op.a.ptx_type) and is_geom_supported(op.a.geom)):
        return False

    if satf and op.a.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
        return False

    # sub-integer types require row/col layout.
    if op.a.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (is_type_supported(frag.ptx_type) and is_geom_supported(frag.geom)):
        return False
    if frag.ptx_type in ["s4", "u4", "b1"]:
        # sub-integer types require sm_75 and ptx63, row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True


def get_builtin_prefix(frag):
    prefix = None
    if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
        if frag.ptx_type in ["f16", "f32"]:
            prefix = "__hmma"
        elif frag.ptx_type == "bf16":
            prefix = "__mma_bf16"
        else:
            prefix = "__imma"
    elif frag.geom == "m8n8k32":
        prefix = "__imma"  # sub-integers
    elif frag.geom == "m8n8k128":
        prefix = "__bmma"
    elif frag.geom == "m8n8k4":
        prefix = "__dmma"
    elif frag.geom == "m16n16k8":
        if frag.ptx_type == "f32":
            prefix = "__mma"
        else:
            prefix = "__mma_tf32"
    assert prefix
    return prefix


def get_ldst_builtin_name(frag):
    prefix = get_builtin_prefix(frag)

    if prefix == "__hmma":
        suffix = "" if frag.frag in ["a", "b"] else frag.ptx_type
    elif prefix in ["__dmma", "__mma_bf16", "__mma_tf32"]:
        suffix = "" if frag.frag in ["a", "b", "c"] else frag.ptx_type
    else:
        suffix = "" if frag.frag == "c" else frag.ptx_type
        if suffix == "s32":
            suffix = "i32"

    if frag.frag == "d":
        ifrag = "c"
        op = "st"
    else:
        ifrag = frag.frag
        op = "ld"

    name = "%s_%s_%s_%s%s" % (
        prefix,
        frag.geom,
        op,
        ifrag,
        "_" + suffix if suffix else "",
    )
    return name


def get_mma_builtin_name(op):
    prefix = get_builtin_prefix(op.a)

    if prefix == "__hmma":
        suffix = op.d.ptx_type + op.c.ptx_type
    elif prefix in ["__mma_bf16", "__mma_tf32"]:
        suffix = op.d.ptx_type
    else:
        suffix = op.a.ptx_type

    name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
        prefix=prefix, geom=op.a.geom, b1op=op.b1op.replace(".", "_"), suffix=suffix
    )
    return name
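# A minimal self-check sketch (illustrative; it is not called by the generator).
# The expected strings are examples of clang builtin names the helpers above
# produce, e.g. __hmma_m16n16k16_ld_a for an f16 'a' fragment load and
# __hmma_m16n16k16_mma_f32f16 for an f16 MMA whose result (D) is f32 and whose
# accumulator (C) is f16.
def _check_example_builtin_names():
    a_f16 = MMAFrag("m16n16k16", "a", "f16")
    assert get_ldst_builtin_name(a_f16) == "__hmma_m16n16k16_ld_a"
    # s32 store fragments are spelled with an i32 suffix.
    assert get_ldst_builtin_name(MMAFrag("m16n16k16", "d", "s32")) == "__imma_m16n16k16_st_c_i32"
    op = MMAOp(
        a_f16,
        MMAFrag("m16n16k16", "b", "f16"),
        MMAFrag("m16n16k16", "c", "f16"),
        MMAFrag("m16n16k16", "d", "f32"),
    )
    assert get_mma_builtin_name(op) == "__hmma_m16n16k16_mma_f32f16"

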
def get_required_sm(frag, b1op=""):
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 80
    if frag.ptx_type in ["u4", "s4", "b1"]:
        if b1op == ".and.popc":
            return 80
        return 75
    if frag.ptx_type in ["s8", "u8"]:
        return 72
    if frag.ptx_type == "s32":
        if frag.geom in ["m8n8k32", "m8n8k128"]:  # s4/u4/b1
            return 75
        else:  # s8/u8
            return 72
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k8":
            return 80
        else:
            return 70
    assert False


def get_required_ptx(frag, b1op=""):
    if frag.ptx_type == "b1" and b1op == ".and.popc":
        return 71
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 70
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k16":
            return 60
        if frag.geom == "m16n16k8":
            return 70
        return 61
    return 63


def get_src_dst_prefix(frag):
    if frag.ptx_type == "f32":
        return "f"
    if frag.ptx_type == "f64":
        return "d"
    if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
        return "f"
    return ""


def gen_wmma_ldst_tests(results):
    load_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${src}, ldm, ${blayout});
""".rstrip()
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"
    )

    for frag, layout in sorted(product(get_ldst_ops(), ["row", "col"]), key=str):

        if not is_ldst_variant_supported(frag, layout):
            continue

        src_dst_prefix = get_src_dst_prefix(frag)

        min_sm = get_required_sm(frag)
        min_ptx = get_required_ptx(frag)
        # TF32 uses f32 for accumulator loads.
        if frag.geom == "m16n16k8" and frag.frag == "c":
            assert frag.ptx_type == "tf32"
            itype = "f32"
        else:
            itype = frag.ptx_type

        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_ldst_builtin_name(frag),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": src_dst_prefix + "dst",
            "src": src_dst_prefix + "src",
            "blayout": 0 if layout == "row" else 1,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "frag": frag.frag,
                    "geom": frag.geom,
                    "ilayout": layout,
                    "itype": itype,
                    "op": "store" if frag.frag == "d" else "load",
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(load_template).substitute(params)

    return results


def mma_signature(op):
    if op.a.ptx_type == "f16":
        # FP16 ops are identified by accumulator & result type.
        return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)
    else:
        # Other ops are identified by input type.
        return op.a.ptx_type


# Get numeric value for rowcol parameter of the builtin.
# AFAICT it uses the encoding accepted by NVVM intrinsics:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma
def get_ilayout(a, b):
    return {"row.row": 0, "row.col": 1, "col.row": 2, "col.col": 3}[a + "." + b]
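# For orientation (derived from load_template above, not extra functionality):
# the f16 'a' fragment of m16n16k16 with a row layout expands to roughly
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16
#   // expected-error-re@+1 {{'__hmma_m16n16k16_ld_a' needs target feature (sm_70{{.*}},(ptx60{{.*}}}}
#   __hmma_m16n16k16_ld_a(dst, src, ldm, 0);
# get_ilayout() collapses the (a, b) layout pair into the single integer
# argument the MMA builtins take, e.g. get_ilayout("row", "col") == 1.

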
def gen_wmma_mma_tests(results):
    mma_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
""".rstrip()
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"

    for op, alayout, blayout, satf in sorted(
        product(get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]),
        key=str,
    ):

        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        asrc_prefix = get_src_dst_prefix(op.a)
        csrc_prefix = get_src_dst_prefix(op.c)
        ddst_prefix = get_src_dst_prefix(op.d)
        if op.a.ptx_type == "b1":  # .b1 MMA has no satf argument.
            isatf_arg = ""
        else:
            isatf_arg = ", 1" if satf else ", 0"
        min_sm = get_required_sm(op.a, op.b1op)
        min_ptx = get_required_ptx(op.a, op.b1op)
        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_mma_builtin_name(op),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": ddst_prefix + "dst",
            "asrc": asrc_prefix + "src",
            "csrc": csrc_prefix + "src",
            "ilayout": get_ilayout(alayout, blayout),
            "maybe_satf": isatf_arg,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "geom": op.a.geom,
                    "alayout": alayout,
                    "blayout": blayout,
                    "intrinsic_signature": mma_signature(op),
                    "satf": satf,
                    "b1op": op.b1op,
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

    return results
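# For orientation (one concrete expansion of mma_template above): an f16 MMA on
# m16n16k16 with f32 result, row/col layouts, and no .satfinite becomes roughly
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f16
#   // expected-error-re@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature (sm_70{{.*}},(ptx60{{.*}}}}
#   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 1, 0);
# The a/b operands share ${asrc}, the final two arguments are the layout code
# from get_ilayout() and the satf flag.

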
def gen_tests():
    results = gen_wmma_ldst_tests(defaultdict(str))
    results = gen_wmma_mma_tests(results)

    run_template = r"""
//
// *** DO NOT EDIT ***
//
// This test has been automatically generated by
// builtins-nvptx-mma.py --ptx=${ptx} --gpu-arch=${sm}
//
// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx}
// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \
// ${run}: -fcuda-is-device -target-feature +ptx${ptx} \
// ${run}: -DPTX=${ptx} -DSM=${sm} \
// ${run}: -S -emit-llvm -o - -x cuda %s \
// ${run}: | FileCheck -check-prefixes=${check_labels} %s
// Verify that all builtins have correct constraints.
// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \
// ${run}: -target-cpu sm_60 -target-feature +ptx42 \
// ${run}: -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \
// ${run}: -verify %s
"""

    def supported_variants(ptx, sm, results):
        return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm]

    print(
        Template(run_template).substitute(
            {
                "run": "RUN",  # To avoid lit misinterpreting the template
                "ptx": ptx_version,
                "sm": gpu_arch,
                "check_labels": ",".join(
                    [
                        "CHECK_PTX%d_SM%d" % (ptx_, sm_)
                        for ptx_, sm_ in supported_variants(
                            ptx_version, gpu_arch, results
                        )
                    ]
                ),
            }
        )
    )

    print(
        """
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
#define __shared__ __attribute__((shared))
#define __constant__ __attribute__((constant))

typedef unsigned long long uint64_t;
#endif

// CHECK-LABEL: test_wmma_builtins
__device__ void test_wmma_builtins(int *src, int *dst,
                                   float *fsrc, float *fdst,
                                   double *dsrc, double *ddst, int ldm) {
"""
    )

    for (ptx, sm), tests in sorted(results.items()):
        print()
        print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm))
        print(tests)
        print("#endif // (PTX >= %d) && (SM >= %d)" % (ptx, sm))

    print("}")


parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()
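# Example invocation (the PTX/SM values and output file name are illustrative;
# pick the versions the generated test should cover):
#   python builtins-nvptx-mma.py --ptx=71 --gpu-arch=80 > builtins-nvptx-mma.cu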