xref: /llvm-project/llvm/test/CodeGen/NVPTX/wmma.py (revision 4f33331317d1f14b24f0f5544d8adb85e5d54616)
1# This test generates all variants of wmma intrinsics and verifies that LLVM
2# generates correct instructions for them. This is the test generator only.  The
3# test scripts themselves are in wmma-ptx*-sm*.py files.
4
5# RUN: true
6
7from __future__ import print_function
8
9import argparse
10from itertools import product
11from string import Template
12
13class MMAType:
14    def __init__(self, ptx_type):
15        self.ptx_type = ptx_type
16        self.llvm_type = {
17            "f16": "<2 x half>",
18            "f32": "float",
19            "f64": "double",
20            "s32": "i32",
21            "b16": "i32",
22            "s8": "i32",
23            "u8": "i32",
24            "s4": "i32",
25            "u4": "i32",
26            "b1": "i32",
27            "bf16": "i32",
28            "tf32": "i32",
29        }[ptx_type]
30
31        self.ptx_reg_pattern = {
32            "f16": "%r[0-9]+",
33            "f32": "%f[0-9]+",
34            "f64": "%fd[0-9]+",
35        }.get(ptx_type, "%r[0-9]+")
36
37    def __repr__(self):
38        return "%s/%s" % (self.ptx_type, self.llvm_type)
39
40
41class MMAFrag:
42    def __init__(self, geom, frag, ptx_elt_type):
43        self.geom = geom
44        self.frag = frag
45        self.mma_type = MMAType(ptx_elt_type)
46        self.nregs = {
47            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
48            "m16n16k16:a:u8": 2,
49            "m16n16k16:a:s8": 2,
50            "m16n16k16:b:u8": 2,
51            "m16n16k16:b:s8": 2,
52            "m16n16k16:c:s32": 8,
53            "m16n16k16:d:s32": 8,
54            "m8n32k16:a:u8": 1,
55            "m8n32k16:a:s8": 1,
56            "m8n32k16:b:u8": 4,
57            "m8n32k16:b:s8": 4,
58            "m8n32k16:c:s32": 8,
59            "m8n32k16:d:s32": 8,
60            "m32n8k16:a:u8": 4,
61            "m32n8k16:a:s8": 4,
62            "m32n8k16:b:u8": 1,
63            "m32n8k16:b:s8": 1,
64            "m32n8k16:c:s32": 8,
65            "m32n8k16:d:s32": 8,
66            "m8n8k16:a:u8": 1,
67            "m8n8k16:a:s8": 1,
68            "m8n8k16:b:u8": 1,
69            "m8n8k16:b:s8": 1,
70            "m8n8k16:c:s32": 2,
71            "m8n8k16:d:s32": 2,
72            "m16n8k16:a:u8": 2,
73            "m16n8k16:a:s8": 2,
74            "m16n8k16:b:u8": 1,
75            "m16n8k16:b:s8": 1,
76            "m16n8k16:c:s32": 4,
77            "m16n8k16:d:s32": 4,
78            "m16n8k32:a:u8": 4,
79            "m16n8k32:a:s8": 4,
80            "m16n8k32:b:u8": 2,
81            "m16n8k32:b:s8": 2,
82            "m16n8k32:c:s32": 4,
83            "m16n8k32:d:s32": 4,
84            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
85            "m8n8k32:a:u4": 1,
86            "m8n8k32:a:s4": 1,
87            "m8n8k32:b:u4": 1,
88            "m8n8k32:b:s4": 1,
89            "m8n8k32:c:s32": 2,
90            "m8n8k32:d:s32": 2,
91            "m16n8k32:a:u4": 2,
92            "m16n8k32:a:s4": 2,
93            "m16n8k32:b:u4": 1,
94            "m16n8k32:b:s4": 1,
95            "m16n8k32:c:s32": 4,
96            "m16n8k32:d:s32": 4,
97            "m16n8k64:a:u4": 4,
98            "m16n8k64:a:s4": 4,
99            "m16n8k64:b:u4": 2,
100            "m16n8k64:b:s4": 2,
101            "m16n8k64:c:s32": 4,
102            "m16n8k64:d:s32": 4,
103            # b1 -> s32 @ m8n8k128(b1)
104            "m8n8k128:a:b1": 1,
105            "m8n8k128:b:b1": 1,
106            "m8n8k128:c:s32": 2,
107            "m8n8k128:d:s32": 2,
108            "m16n8k128:a:b1": 2,
109            "m16n8k128:b:b1": 1,
110            "m16n8k128:c:s32": 4,
111            "m16n8k128:d:s32": 4,
112            "m16n8k256:a:b1": 4,
113            "m16n8k256:b:b1": 2,
114            "m16n8k256:c:s32": 4,
115            "m16n8k256:d:s32": 4,
116            # bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
117            "m16n16k16:a:bf16": 4,
118            "m16n16k16:b:bf16": 4,
119            "m8n32k16:a:bf16": 2,
120            "m8n32k16:b:bf16": 8,
121            "m32n8k16:a:bf16": 8,
122            "m32n8k16:b:bf16": 2,
123            "m16n8k16:a:bf16": 4,
124            "m16n8k16:b:bf16": 2,
125            "m16n8k16:c:f32": 4,
126            "m16n8k16:d:f32": 4,
127            "m16n8k8:a:bf16": 2,
128            "m16n8k8:b:bf16": 1,
129            "m16n8k8:c:f32": 4,
130            "m16n8k8:d:f32": 4,
131            "m8n8k4:a:f64": 1,
132            "m8n8k4:b:f64": 1,
133            "m8n8k4:c:f64": 2,
134            "m8n8k4:d:f64": 2,
135            # tf32 -> s32 @ m16n16k8
136            "m16n16k8:a:tf32": 4,
137            "m16n16k8:b:tf32": 4,
138            "m16n8k4:a:tf32": 2,
139            "m16n8k4:b:tf32": 1,
140            "m16n8k4:c:f32": 4,
141            "m16n8k4:d:f32": 4,
142            "m16n8k8:a:tf32": 4,
143            "m16n8k8:b:tf32": 2,
144            "m16n8k8:c:f32": 4,
145            "m16n8k8:d:f32": 4,
146            "m8n8k4:a:f16": 2,
147            "m8n8k4:b:f16": 2,
148            "m16n8k8:a:f16": 2,
149            "m16n8k8:b:f16": 1,
150            "m16n8k8:c:f16": 2,
151            "m16n8k8:d:f16": 2,
152            "m16n8k8:c:f32": 4,
153            "m16n8k8:d:f32": 4,
154            "m16n8k16:a:f16": 4,
155            "m16n8k16:b:f16": 2,
156            "m16n8k16:c:f16": 2,
157            "m16n8k16:d:f16": 2,
158            "m16n8k16:c:f32": 4,
159            "m16n8k16:d:f32": 4,
160            # ldmatrix
161            "m8n8:x1:b16": 1,
162            "m8n8:x2:b16": 2,
163            "m8n8:x4:b16": 4,
164        }.get(
165            "%s:%s:%s" % (geom, frag, ptx_elt_type),
166            {
167                # All other FP shape/fragment/type combinations have the same size
168                "a:f16": 8,
169                "b:f16": 8,
170                "c:f16": 4,
171                "d:f16": 4,
172                "c:f32": 8,
173                "d:f32": 8,
174            }.get("%s:%s" % (frag, ptx_elt_type), None),
175        )
176        assert self.nregs
177
178    def __repr__(self):
179        return "%s:%s:%s%s" % (
180            self.geom,
181            self.frag,
182            self.mma_type,
183            "" if self.nregs == 1 else ("*%d" % self.nregs),
184        )
185
186
187class MMAOp:
188    def __init__(self, a, b, c, d):
189        self.a = a
190        self.b = b
191        self.c = c
192        self.d = d
193
194    def __repr__(self):
195        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)
196
197
198def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
199    ops = []
200    for geom, type_a, type_c in product(geoms, types_a, types_c):
201        for type_b, type_d in product(
202            types_b if types_b else [type_a], types_d if types_d else [type_c]
203        ):
204            ops.append(
205                MMAOp(
206                    MMAFrag(geom, "a", type_a),
207                    MMAFrag(geom, "b", type_b),
208                    MMAFrag(geom, "c", type_c),
209                    MMAFrag(geom, "d", type_d),
210                )
211            )
212    return ops
213
214
215def make_ldst_ops(geoms, frags, types):
216    return [
217        MMAFrag(geom, frag, ptx_type)
218        for (geom, frag, ptx_type) in product(geoms, frags, types)
219    ]
220
221
222def make_ldmatrix_ops(geoms, frags, types):
223    return [
224        MMAFrag(geom, frag, ptx_type)
225        for (geom, frag, ptx_type) in product(geoms, frags, types)
226    ]
227
228
229def get_wmma_ops():
230    return (
231        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
232        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
233        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
234        + make_mma_ops(
235            ["m16n16k16", "m32n8k16", "m8n32k16"],
236            ["f16"],
237            [],
238            ["f16", "f32"],
239            ["f16", "f32"],
240        )
241        + make_mma_ops(
242            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
243        )
244        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
245        + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])
246    )
247
248
249def get_mma_ops():
250    return (
251        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
252        + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
253        + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
254        + make_mma_ops(
255            ["m8n8k4", "m16n8k8", "m16n8k16"],
256            ["f16"],
257            [],
258            ["f16", "f32"],
259            ["f16", "f32"],
260        )
261        + make_mma_ops(
262            ["m8n8k16", "m16n8k16", "m16n8k32"], ["s8", "u8"], ["s8", "u8"], ["s32"], []
263        )
264        + make_mma_ops(
265            ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
266        )
267        + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
268    )
269
270
271def get_ldst_ops(kind):
272    ldst_ops = (
273        make_ldst_ops(
274            ["m16n16k16", "m32n8k16", "m8n32k16"],
275            ["a", "b"],
276            ["f16", "u8", "s8", "bf16"],
277        )
278        + make_ldst_ops(
279            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
280        )
281        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
282        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
283        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
284        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
285        + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"])
286        + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])
287    )
288    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
289
290
291def get_ldmatrix_ops():
292    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
293
294
295def is_wmma_geom_supported(geom):
296    # geometries for FP and ints.
297    if geom in ["m8n32k16", "m32n8k16"]:
298        return ptx_version >= 61
299    # geometries for sub-ints.
300    if geom in ["m8n8k32", "m8n8k128"]:
301        return ptx_version >= 63 and gpu_arch >= 75
302    if geom == "m16n16k16":
303        return ptx_version >= 60
304    if geom == "m16n8k8":
305        return ptx_version >= 65
306    if geom in ["m16n16k8", "m8n8k4"]:
307        return ptx_version >= 70
308    assert False  # Unexpected geometry.
309
310
311def is_mma_geom_supported(geom):
312    # geometries for FP and ints.
313    if geom == "m8n8k4":
314        return ptx_version >= 64
315    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
316        return ptx_version >= 65
317    if geom in [
318        "m16n8k16",
319        "m16n8k4",
320        "m16n8k32",
321        "m16n8k64",
322        "m8n8k128",
323        "m16n8k128",
324        "m16n8k256",
325    ]:
326        return ptx_version >= 70
327    assert False  # Unexpected geometry.
328
329
330def is_ldmatrix_geom_supported(geom):
331    if geom in ["m8n8"]:
332        return ptx_version >= 65 and gpu_arch >= 75
333    assert False  # Unexpected geometry.
334
335
336def is_type_supported(ptx_type):
337    if ptx_type in ["s8", "u8", "s32"]:
338        return ptx_version >= 63 and gpu_arch >= 72
339    if ptx_type in ["s4", "u4", "b1"]:
340        return ptx_version >= 63 and gpu_arch >= 75
341    if ptx_type == "b16":
342        return ptx_version >= 65 and gpu_arch >= 75
343    if ptx_type in ["bf16", "tf32", "f64"]:
344        return ptx_version >= 70
345    return ptx_version >= 60 and gpu_arch >= 70
346
347
348def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
349    if not (
350        is_type_supported(op.a.mma_type.ptx_type) and is_wmma_geom_supported(op.a.geom)
351    ):
352        return False
353
354    # rnd is only supported for FP64 WMMA
355    if rnd and op.a.mma_type.ptx_type != "f64":
356        return False
357
358    if satf:
359        # satfinite for floating points was removed in PTX 6.5
360        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
361            return False
362        if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
363            return False
364
365    # sub-integer require row/col layout.
366    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
367        return layout_a == "row" and layout_b == "col"
368    return True
369
370
371def is_mma_variant_supported(op, layout_a, layout_b, satf):
372    if not (
373        is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
374    ):
375        return False
376
377    if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
378        return False
379
380    # If the type of C is f32 then so must the type of D
381    if (
382        op.a.geom == "m8n8k4"
383        and op.c.mma_type.ptx_type == "f32"
384        and op.d.mma_type.ptx_type != "f32"
385    ):
386        return False
387
388    # A and B type must be the same. C and D type must be the same
389    if op.a.geom == "m16n8k8" and (
390        op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
391        or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
392    ):
393        return False
394
395    # C and D type must be the same
396    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
397        return False
398
399    # Require row/col layout for all MMA except m8n8k4 on FP16
400    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
401        return layout_a == "row" and layout_b == "col"
402    return True
403
404
405def is_ldst_variant_supported(frag, layout):
406    if not (
407        is_type_supported(frag.mma_type.ptx_type) and is_wmma_geom_supported(frag.geom)
408    ):
409        return False
410    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
411        # sub-integer require sm_75 and ptx63, row/col layout for a/b.
412        return (
413            (frag.frag == "a" and layout == "row")
414            or (frag.frag == "b" and layout == "col")
415            or frag.frag in ["c", "d"]
416        )
417    return True
418
419
420def is_ldmatrix_variant_supported(frag):
421    if not (
422        is_type_supported(frag.mma_type.ptx_type)
423        and is_ldmatrix_geom_supported(frag.geom)
424    ):
425        return False
426    return frag.frag in ["x1", "x2", "x4"]
427
428
429def make_wmma_slice_ty(frag):
430    return [frag.mma_type.llvm_type] * frag.nregs
431
432
433def make_wmma_ld_ret_ty(frag):
434    results = make_wmma_slice_ty(frag)
435    if len(results) == 1:
436        return "%s" % results[0]
437    return "{%s}" % ", ".join(results)
438
439
440# returns address space
441def get_aspace(space):
442    space_map = {
443        ".global": 1,
444        ".shared": 3,
445        ".const": 4,
446        ".local": 5,
447        ".param": 101,
448        "": 0,
449        ".generic": 0,
450    }
451    return space_map[space]
452
453
454def get_pspace(space):
455    return "p%di8" % get_aspace(space)
456
457
458def check_pattern(frag):
459    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
460
461
462def gen_wmma_load_tests():
463    load_template = """
464declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
465
466; CHECK-LABEL: .func {{.*}}test_${function}(
467define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
468; CHECK: ${instruction}
469; CHECK: {${check_result}}
470; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
471  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
472  ret ${ret_ty} %v0;
473}
474
475; CHECK-LABEL: .func{{.*}}test_${function}_o(
476define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
477; CHECK: ${instruction}
478; CHECK: {${check_result}}
479; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
480  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
481  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
482  ret ${ret_ty} %v0;
483}
484"""
485    intrinsic_template = (
486        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
487    )
488    instruction_template = (
489        "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
490    )
491
492    generated_items = []
493
494    for frag, layout, space, stride in product(
495        get_ldst_ops("load"),
496        ["row", "col"],
497        ["", ".shared", ".global"],
498        ["", ".stride"],
499    ):
500        if not is_ldst_variant_supported(frag, layout):
501            continue
502
503        params = {
504            "abc": frag.frag,
505            "aligned": ".aligned" if ptx_version >= 63 else "",
506            "layout": layout,
507            "space": space,
508            "stride": stride,
509            "itype": frag.mma_type.ptx_type,
510            "pspace": get_pspace(space),
511            "as": "addrspace(%d)" % get_aspace(space),
512            "geom": frag.geom,
513        }
514
515        test_params = params
516        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
517        test_params["function"] = test_params["intrinsic"].replace(".", "_")
518        test_params["instruction"] = Template(instruction_template).substitute(params)
519        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
520        test_params["check_result"] = check_pattern(frag)
521
522        if stride:
523            test_params["extra_args"] = ", i32 %stride"
524            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
525        else:
526            test_params["extra_args"] = ""
527            test_params["stride_pattern"] = ""
528
529        print(Template(load_template).substitute(test_params))
530
531        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
532
533    return generated_items
534
535
536def make_wmma_slice_args(frag):
537    return ", ".join(
538        [
539            "%s %%%s%d" % (t, frag.frag, i)
540            for i, t in enumerate(make_wmma_slice_ty(frag))
541        ]
542    )
543
544
545def gen_wmma_store_tests():
546    store_template = """
547declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
548
549; CHECK-LABEL: .func {{.*}}test_${function}(
550define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
551; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
552; CHECK: {${check_args}}
553; CHECK: ${stride_pattern}
554  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
555  ret void
556}
557
558; CHECK-LABEL: .func{{.*}}test_${function}_o(
559define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
560; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
561; CHECK: ${check_args}
562; CHECK: ${stride_pattern}
563  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
564  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
565  ret void
566}
567"""
568    intrinsic_template = (
569        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
570    )
571    instruction_template = (
572        "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
573    )
574
575    generated_items = []
576
577    for frag, layout, space, stride in product(
578        get_ldst_ops("store"),
579        ["row", "col"],
580        ["", ".shared", ".global"],
581        ["", ".stride"],
582    ):
583
584        if not is_ldst_variant_supported(frag, layout):
585            continue
586
587        params = {
588            "abc": frag.frag,
589            "aligned": ".aligned" if ptx_version >= 63 else "",
590            "layout": layout,
591            "space": space,
592            "stride": stride,
593            "itype": frag.mma_type.ptx_type,
594            "pspace": get_pspace(space),
595            "as": "addrspace(%d)" % get_aspace(space),
596            "geom": frag.geom,
597        }
598
599        test_params = params
600        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
601        test_params["function"] = test_params["intrinsic"].replace(".", "_")
602        test_params["instruction"] = Template(instruction_template).substitute(params)
603        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
604        test_params["check_args"] = check_pattern(frag)
605        if stride:
606            test_params["extra_args"] = ", i32 %stride"
607            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
608        else:
609            test_params["extra_args"] = ""
610            test_params["stride_pattern"] = ";"
611        test_params["args"] = make_wmma_slice_args(frag)
612
613        print(Template(store_template).substitute(test_params))
614        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
615
616    return generated_items
617
618
619def gen_ldmatrix_tests():
620    ldmatrix_template = """
621declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);
622
623; CHECK-LABEL: .func {{.*}}test_${function}(
624define ${ret_ty} @test_${function}(i8 ${as}* %src) {
625; CHECK: ${instruction}
626; CHECK: {${check_result}}
627; CHECK: [%rd{{[0-9]+}}]
628  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
629  ret ${ret_ty} %v0;
630}
631
632; CHECK-LABEL: .func{{.*}}test_${function}_o(
633define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
634; CHECK: ${instruction}
635; CHECK: {${check_result}}
636; CHECK: [%rd{{[0-9]+}}+128]
637  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
638  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
639  ret ${ret_ty} %v0;
640}
641"""
642    intrinsic_template = (
643        "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
644    )
645    instruction_template = (
646        "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
647    )
648
649    generated_items = []
650
651    for frag, space, trans in product(
652        get_ldmatrix_ops(),
653        ["", ".shared"],
654        ["", ".trans"],
655    ):
656        if not is_ldmatrix_variant_supported(frag):
657            continue
658
659        params = {
660            "frag": frag.frag,
661            "space": space,
662            "trans": trans,
663            "itype": frag.mma_type.ptx_type,
664            "pspace": get_pspace(space),
665            "as": "addrspace(%d)" % get_aspace(space),
666            "geom": frag.geom,
667        }
668
669        test_params = params
670        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
671        test_params["function"] = test_params["intrinsic"].replace(".", "_")
672        test_params["instruction"] = Template(instruction_template).substitute(params)
673        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
674        test_params["check_result"] = check_pattern(frag)
675
676        print(Template(ldmatrix_template).substitute(test_params))
677
678        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
679
680    return generated_items
681
682
683def mma_signature(op):
684    if op.a.mma_type.ptx_type == "f16":
685        # FP16 ops identified by accumulator & result type.
686        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
687    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
688        # other ops are identified by input types.
689        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
690    else:
691        # if input types are the same, it only appears once.
692        return op.a.mma_type.ptx_type
693
694
695def mma_ptx_signature(op):
696    # Encode all four types as D.A.B.C
697    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
698
699
700def wmma_signature(op):
701    if op.a.mma_type.ptx_type == "f16":
702        # FP16 ops identified by accumulator & result type.
703        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
704    else:
705        # other ops are identified by input type.
706        return op.a.mma_type.ptx_type
707
708
709def wmma_ptx_signature(op):
710    if op.a.mma_type.ptx_type == "f16":
711        # FP16 instructions use D.C
712        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
713    else:
714        # other instructions encode all four types as D.A.B.C
715        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
716
717
718def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
719    mma_template = """
720declare ${ret_ty} @${intrinsic}(
721        ${args});
722
723; CHECK-LABEL: .func {{.*}}test_${function}(
724define ${ret_ty} @test_${function}(
725        ${args}) {
726; CHECK: ${instruction}
727; CHECK-NEXT: ${check_d}
728; CHECK-NEXT: ${check_a}
729; CHECK-NEXT: ${check_b}
730; CHECK-NEXT: ${check_c}
731  %r = call ${ret_ty} @${intrinsic}(
732        ${args});
733  ret ${ret_ty} %r;
734}
735"""
736
737    test_params = params
738    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
739    test_params["function"] = test_params["intrinsic"].replace(".", "_")
740    test_params["instruction"] = Template(instruction_template).substitute(params)
741    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
742    test_params["check_a"] = check_pattern(op.a)
743    test_params["check_b"] = check_pattern(op.b)
744    test_params["check_c"] = check_pattern(op.c)
745    test_params["check_d"] = check_pattern(op.d)
746    args = ",\n        ".join(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
747    test_params["args"] = args
748    print(Template(mma_template).substitute(test_params))
749    return (test_params["intrinsic"], test_params["instruction"])
750
751
752def get_b1_ops(ptx_type):
753    if ptx_type != "b1":
754        return [""]
755    if ptx_version >= 71:
756        return [".xor.popc", ".and.popc"]
757    return [".xor.popc"]
758
759
760def gen_wmma_mma_tests():
761    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
762    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
763
764    generated_items = []
765
766    for op, alayout, blayout, rnd, satf in product(
767        get_wmma_ops(),
768        ["row", "col"],
769        ["row", "col"],
770        [".rn", ".rz", ".rm", ".rp", ""],
771        [".satfinite", ""],
772    ):
773
774        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
775            continue
776
777        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
778            params = {
779                "aligned": ".aligned" if ptx_version >= 63 else "",
780                "alayout": alayout,
781                "blayout": blayout,
782                "intrinsic_signature": wmma_signature(op),
783                "ptx_signature": wmma_ptx_signature(op),
784                "satf": satf,
785                "rnd": rnd,
786                "geom": op.a.geom,
787                "b1op": b1op,
788            }
789
790            intrinsic_template = wmma_intrinsic_template
791            instruction_template = wmma_instruction_template
792
793            generated_items.append(
794                common_mma_test_gen(
795                    params, op, intrinsic_template, instruction_template
796                )
797            )
798
799    return generated_items
800
801
802def gen_mma_tests():
803    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
804    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
805
806    generated_items = []
807
808    for op, alayout, blayout, satf in product(
809        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
810    ):
811
812        if not is_mma_variant_supported(op, alayout, blayout, satf):
813            continue
814
815        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
816            params = {
817                "aligned": ".aligned" if ptx_version >= 63 else "",
818                "alayout": alayout,
819                "blayout": blayout,
820                "intrinsic_signature": mma_signature(op),
821                "ptx_signature": mma_ptx_signature(op),
822                "satf": satf,
823                "geom": op.a.geom,
824                "b1op": b1op,
825            }
826
827            intrinsic_template = mma_intrinsic_template
828            instruction_template = mma_instruction_template
829
830            generated_items.append(
831                common_mma_test_gen(
832                    params, op, intrinsic_template, instruction_template
833                )
834            )
835
836    return generated_items
837
838
839# Append complete list of intrinsics and instructions we've generated tests for.
840# Generate set of checks to verify that that we did generate sensible set of
841# tests for the given combination of PTX and SM variants.
842#
843def gen_check_unsupported_ops(items):
844    print(
845        "; Complete list of intrinsics supported by PTX%d on sm_%d"
846        % (ptx_version, gpu_arch)
847    )
848    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
849    print(
850        """
851
852; NOEXTGEOM-NOT: {{m8n32|m32n8}}
853; NOINT-NOT: .{{s32|s8}}
854; NOSUBINT-NOT: {{s4|u4|b1}}
855; NOMMA-NOT: .m8n8k4.
856; NOALTFLOAT-NOT: .{{bf16|tf32}}
857; NODOUBLE-NOT: .f64
858; NOLDMATRIX-NOT: ldmatrix.sync.aligned
859
860; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
861; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
862; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
863; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
864; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
865; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32
866
867; PTX60 adds support for m32n8k16/m8n32k16 geometries.
868; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
869; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
870; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
871; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
872; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
873; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32
874
875; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
876; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
877; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
878; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
879; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
880; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32
881
882; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
883; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
884; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
885; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
886; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
887; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
888; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
889; INT-DAG: m16n16k16.mma.{{.*}}.u8
890; INT-DAG: m16n16k16.mma.{{.*}}.s8
891; INT-DAG: m8n32k16.mma.{{.*}}.u8
892; INT-DAG: m8n32k16.mma.{{.*}}.s8
893; INT-DAG: m32n8k16.mma.{{.*}}.u8
894; INT-DAG: m32n8k16.mma.{{.*}}.s8
895
896; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
897; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
898; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
899; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
900; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
901; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
902; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
903; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1
904
905; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
906; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
907; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
908; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
909; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
910; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
911; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
912; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32
913
914; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
915; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
916; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64
917
918; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
919; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
920; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
921; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32
922
923; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
924; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
925; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
926; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
927; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
928; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
929; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
930; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
931; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
932; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
933
934; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
935; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
936; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
937; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
938; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
939; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
940; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
941; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
942; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
943; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
944; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
945; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
946
947; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
948; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
949; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
950; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
951; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
952; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
953; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
954; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
955; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
956; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
957; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
958; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
959; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
960; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
961; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
962; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
963; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
964; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
965; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
966; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
967; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
968; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
969; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
970; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
971; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
972; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
973; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
974; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
975; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
976;
977
978"""
979    )
980
981    print("; INTRINSICS_LIST_BEGIN")
982    for intrinsic, instruction in sorted(items):
983        print("; ", intrinsic, " -> ", instruction, "")
984    print("; INTRINSICS_LIST_END")
985    print("; INTRINSICS: ; INTRINSICS_LIST_END")
986
987
988def gen_tests():
989    items = gen_wmma_load_tests()
990    items += gen_wmma_store_tests()
991    items += gen_ldmatrix_tests()
992    items += gen_wmma_mma_tests()
993    items += gen_mma_tests()
994    gen_check_unsupported_ops(items)
995
996
997def main():
998    global ptx_version
999    global gpu_arch
1000    parser = argparse.ArgumentParser()
1001    parser.add_argument("--ptx", type=int, default=60)
1002    parser.add_argument("--gpu-arch", type=int, default=70)
1003    args = parser.parse_args()
1004
1005    ptx_version = args.ptx
1006    gpu_arch = args.gpu_arch
1007
1008    gen_tests()
1009
1010
1011if __name__ == "__main__":
1012    main()
1013