# This script generates all variants of wmma builtins, verifies that clang calls
# correct LLVM intrinsics, and checks that availability of specific builtins is
# constrained by the correct PTX version and the target GPU variant.

# Dummy test run to avoid lit warnings.
# RUN: echo "This is not a real test. It's a generator for builtins-nvptx-mma.cu" >/dev/null
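# To regenerate the checked-in builtins-nvptx-mma.cu, redirect this script's output,
# e.g. (the PTX and SM values here are only an example):
#   python builtins-nvptx-mma.py --ptx=71 --gpu-arch=80 > builtins-nvptx-mma.cu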

from __future__ import print_function

import argparse
from collections import defaultdict
from itertools import product
from string import Template


class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.ptx_type = ptx_elt_type

    def __repr__(self):
        return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)


class MMAOp:
    def __init__(self, a, b, c, d, b1op=""):
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.b1op = b1op

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


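# For example, make_mma_ops(["m16n16k16"], ["f16"], [], ["f16", "f32"], ["f16", "f32"])
# expands to the four (C, D) accumulator/result combinations f16/f16, f16/f32, f32/f16
# and f32/f32 for that geometry; an empty types_b or types_d list falls back to
# type_a / type_c respectively.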
def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
    ops = []
    if b1ops is None:
        b1ops = [""]
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops += [
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                    b1op,
                )
                for b1op in b1ops
            ]
    return ops


def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def get_mma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(
            ["m8n8k128"], ["b1"], [], ["s32"], [], [".xor.popc", ".and.popc"]
        )
    )


def get_ldst_ops():
    # NOTE: fragments are from the point of view of PTX.
    # Fragment `d` is only used by store ops; the others are used by both loads and stores.
    return (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        +
        # TF32 m16n16k8 is odd.
        # For fragment 'C' it uses __mma_*tf32*_m16n16k8_ld_c
        # but 'D' calls __mma_m16n16k8_st_c_*f32*.
        make_ldst_ops(["m16n16k8"], ["a", "b", "c"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["d"], ["f32"])
    )


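# The is_*_supported() predicates below consult ptx_version and gpu_arch, module-level
# globals that are filled in from the command-line arguments at the bottom of the script.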
def is_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70 and gpu_arch >= 80
    assert False  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70 and gpu_arch >= 80
    return ptx_version >= 60 and gpu_arch >= 70


def is_rnd_supported(op):
    # rnd is only supported for FP64 WMMA
    return op.a.ptx_type == "f64"


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (is_type_supported(op.a.ptx_type) and is_geom_supported(op.a.geom)):
        return False

    if satf and op.a.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
        return False

    # sub-integer types require row/col layout.
    if op.a.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (is_type_supported(frag.ptx_type) and is_geom_supported(frag.geom)):
        return False
    if frag.ptx_type in ["s4", "u4", "b1"]:
        # sub-integer types require sm_75 and ptx63, row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True


def get_builtin_prefix(frag):
    prefix = None
    if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
        if frag.ptx_type in ["f16", "f32"]:
            prefix = "__hmma"
        elif frag.ptx_type == "bf16":
            prefix = "__mma_bf16"
        else:
            prefix = "__imma"
    elif frag.geom == "m8n8k32":
        prefix = "__imma"  # sub-integers
    elif frag.geom == "m8n8k128":
        prefix = "__bmma"
    elif frag.geom == "m8n8k4":
        prefix = "__dmma"
    elif frag.geom == "m16n16k8":
        if frag.ptx_type == "f32":
            prefix = "__mma"
        else:
            prefix = "__mma_tf32"
    assert prefix
    return prefix


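# For example, the f16 'a' fragment of m16n16k16 maps to __hmma_m16n16k16_ld_a, while
# the s32 'd' fragment maps to __imma_m16n16k16_st_c_i32 (stores reuse the 'c' fragment
# name, and the s32 suffix is spelled i32 in the builtin).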
def get_ldst_builtin_name(frag):
    prefix = get_builtin_prefix(frag)

    if prefix == "__hmma":
        suffix = "" if frag.frag in ["a", "b"] else frag.ptx_type
    elif prefix in ["__dmma", "__mma_bf16", "__mma_tf32"]:
        suffix = "" if frag.frag in ["a", "b", "c"] else frag.ptx_type
    else:
        suffix = "" if frag.frag == "c" else frag.ptx_type
        if suffix == "s32":
            suffix = "i32"

    if frag.frag == "d":
        ifrag = "c"
        op = "st"
    else:
        ifrag = frag.frag
        op = "ld"

    name = "%s_%s_%s_%s%s" % (
        prefix,
        frag.geom,
        op,
        ifrag,
        "_" + suffix if suffix else "",
    )
    return name


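# For example, an f16 MMA on m16n16k16 with an f32 result (D) and an f16 accumulator (C)
# maps to __hmma_m16n16k16_mma_f32f16, and the b1 xor.popc variant on m8n8k128 maps to
# __bmma_m8n8k128_mma_xor_popc_b1.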
def get_mma_builtin_name(op):
    prefix = get_builtin_prefix(op.a)

    if prefix == "__hmma":
        suffix = op.d.ptx_type + op.c.ptx_type
    elif prefix in ["__mma_bf16", "__mma_tf32"]:
        suffix = op.d.ptx_type
    else:
        suffix = op.a.ptx_type

    name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
        prefix=prefix, geom=op.a.geom, b1op=op.b1op.replace(".", "_"), suffix=suffix
    )
    return name


def get_required_sm(frag, b1op=""):
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 80
    if frag.ptx_type in ["u4", "s4", "b1"]:
        if b1op == ".and.popc":
            return 80
        return 75
    if frag.ptx_type in ["s8", "u8"]:
        return 72
    if frag.ptx_type == "s32":
        if frag.geom in ["m8n8k32", "m8n8k128"]:  # s4/u4/b1
            return 75
        else:  # s8/u8
            return 72
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k8":
            return 80
        else:
            return 70
    assert False


def get_required_ptx(frag, b1op=""):
    if frag.ptx_type == "b1" and b1op == ".and.popc":
        return 71
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 70
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k16":
            return 60
        if frag.geom == "m16n16k8":
            return 70
        return 61
    return 63


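# Pointer-name prefix used to pick the matching parameter of the generated test
# function: "f" selects fsrc/fdst (float), "d" selects dsrc/ddst (double), and the
# empty prefix selects the plain int src/dst.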
def get_src_dst_prefix(frag):
    if frag.ptx_type == "f32":
        return "f"
    if frag.ptx_type == "f64":
        return "d"
    if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
        return "f"
    return ""


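# Each emitted test is a FileCheck line, an expected-error-re line (consumed by the
# -verify run against an older target), and the builtin call itself. For instance, the
# f16 'a' fragment of m16n16k16 with a row layout should come out roughly as:
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16
#   __hmma_m16n16k16_ld_a(dst, src, ldm, 0);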
def gen_wmma_ldst_tests(results):
    load_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${src}, ldm, ${blayout});
""".rstrip()
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"
    )

    for frag, layout in sorted(product(get_ldst_ops(), ["row", "col"]), key=str):

        if not is_ldst_variant_supported(frag, layout):
            continue

        src_dst_prefix = get_src_dst_prefix(frag)

        min_sm = get_required_sm(frag)
        min_ptx = get_required_ptx(frag)
        # TF32 uses f32 for accumulator loads.
        if frag.geom == "m16n16k8" and frag.frag == "c":
            assert frag.ptx_type == "tf32"
            itype = "f32"
        else:
            itype = frag.ptx_type

        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_ldst_builtin_name(frag),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": src_dst_prefix + "dst",
            "src": src_dst_prefix + "src",
            "blayout": 0 if layout == "row" else 1,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "frag": frag.frag,
                    "geom": frag.geom,
                    "ilayout": layout,
                    "itype": itype,
                    "op": "store" if frag.frag == "d" else "load",
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(load_template).substitute(params)

    return results


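# For example, mma_signature() yields "f32.f16" for an f16 op with an f32 D and an f16 C,
# and just "s8" for the signed 8-bit integer op.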
def mma_signature(op):
    if op.a.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)
    else:
        # other ops are identified by input type.
        return op.a.ptx_type


# Get numeric value for rowcol parameter of the builtin
# AFAICT it uses the encoding accepted by NVVM intrinsics:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma
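# For example, get_ilayout("row", "col") == 1.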
def get_ilayout(a, b):
    return {"row.row": 0, "row.col": 1, "col.row": 2, "col.col": 3}[a + "." + b]


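# Same shape of output as the ld/st tests above. For instance, the plain f16 m16n16k16
# MMA with row/row layouts and no satfinite should come out roughly as:
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16
#   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 0, 0);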
def gen_wmma_mma_tests(results):
    mma_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
""".rstrip()
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"

    for op, alayout, blayout, satf in sorted(
        product(get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]),
        key=str,
    ):

        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        asrc_prefix = get_src_dst_prefix(op.a)
        csrc_prefix = get_src_dst_prefix(op.c)
        ddst_prefix = get_src_dst_prefix(op.d)
        if op.a.ptx_type == "b1":  # .b1 MMA has no satf argument.
            isatf_arg = ""
        else:
            isatf_arg = ", 1" if satf else ", 0"
        min_sm = get_required_sm(op.a, op.b1op)
        min_ptx = get_required_ptx(op.a, op.b1op)
        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_mma_builtin_name(op),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": ddst_prefix + "dst",
            "asrc": asrc_prefix + "src",
            "csrc": csrc_prefix + "src",
            "ilayout": get_ilayout(alayout, blayout),
            "maybe_satf": isatf_arg,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "geom": op.a.geom,
                    "alayout": alayout,
                    "blayout": blayout,
                    "intrinsic_signature": mma_signature(op),
                    "satf": satf,
                    "b1op": op.b1op,
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

    return results


def gen_tests():
    results = gen_wmma_ldst_tests(defaultdict(str))
    results = gen_wmma_mma_tests(results)

    run_template = r"""
//
// *** DO NOT EDIT ***
//
//  This test has been automatically generated by
//  builtins-nvptx-mma.py --ptx=${ptx} --gpu-arch=${sm}
//
// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx}
// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \
// ${run}:            -fcuda-is-device -target-feature +ptx${ptx} \
// ${run}:            -DPTX=${ptx} -DSM=${sm} \
// ${run}:            -S -emit-llvm -o - -x cuda %s \
// ${run}:   | FileCheck -check-prefixes=${check_labels} %s
// Verify that all builtins have correct constraints.
// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \
// ${run}:   -target-cpu sm_60 -target-feature +ptx42 \
// ${run}:   -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \
// ${run}:   -verify %s
"""

    def supported_variants(ptx, sm, results):
        return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm]

    print(
        Template(run_template).substitute(
            {
                "run": "RUN",  # To avoid lit misinterpreting the template
                "ptx": ptx_version,
                "sm": gpu_arch,
                "check_labels": ",".join(
                    [
                        "CHECK_PTX%d_SM%d" % (ptx_, sm_)
                        for ptx_, sm_ in supported_variants(
                            ptx_version, gpu_arch, results
                        )
                    ]
                ),
            }
        )
    )

    print(
        """
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
#define __shared__ __attribute__((shared))
#define __constant__ __attribute__((constant))

typedef unsigned long long uint64_t;
#endif

// CHECK-LABEL: test_wmma_builtins
__device__ void test_wmma_builtins(int *src, int *dst,
                                   float *fsrc, float *fdst,
                                   double *dsrc, double *ddst, int ldm) {
"""
    )

    for (ptx, sm), tests in sorted(results.items()):
        print()
        print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm))
        print(tests)
        print("#endif // (PTX >= %d) && (SM >= %d)" % (ptx, sm))

    print("}")


parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()