xref: /llvm-project/mlir/test/python/dialects/gpu/module-to-binary-nvvm.py (revision 6e6da74c8b936e457ca5e56a828823ae6a9f9066)
1# REQUIRES: host-supports-nvptx
2# RUN: %PYTHON %s | FileCheck %s
3
4from mlir.ir import *
5import mlir.dialects.gpu as gpu
6import mlir.dialects.gpu.passes
7from mlir.passmanager import *
8
9
def run(f):
    """Execute test function *f* immediately and hand it back unchanged.

    A ``TEST: <name>`` banner is printed first so the per-test label
    directives in this file have a line to anchor on. *f* is then invoked
    inside a freshly created MLIR ``Context`` with ``Location.unknown()``
    installed as the current location. Returning *f* lets this function be
    used as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        # Entered after the Context above, so the unknown location is
        # created in that context.
        with Location.unknown():
            f()
    return f
15
16
# CHECK-LABEL: testGPUToLLVMBin
@run
def testGPUToLLVMBin():
    """Serialize an NVVM-targeted gpu.module with format=llvm and inspect it.

    Runs ``gpu-module-to-binary{format=llvm}`` over a module containing one
    ``gpu.module`` with an ``sm_70`` NVVM target. It then prints the
    resulting ``gpu.binary`` op and the target, format, payload and
    properties of its first ``#gpu.object``.
    """
    # Parse in the context/location that ``run`` already set up, rather than
    # opening a second Context here. A nested context would exit before the
    # PassManager below is created. The PassManager would then bind to a
    # different MLIRContext than the module it runs on.
    module = Module.parse(
        r"""
module attributes {gpu.container_module} {
  gpu.module @kernel_module1 [#nvvm.target<chip = "sm_70">] {
    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
        %arg5: i64) attributes {gpu.kernel} {
      llvm.return
    }
  }
}
    """
    )
    pm = PassManager("any")
    pm.add("gpu-module-to-binary{format=llvm}")
    pm.run(module.operation)
    # CHECK-LABEL: gpu.binary @kernel_module1
    print(module)

    # First op in the module body is the gpu.binary produced by the pass.
    o = gpu.ObjectAttr(module.body.operations[0].objects[0])
    # CHECK: #gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">
    print(o)
    # CHECK: #nvvm.target<chip = "sm_70">
    print(o.target)
    # CHECK: offload
    print(gpu.CompilationTarget(o.format))
    # CHECK: b'{{.*}}'
    print(o.object)
    # CHECK: None
    print(o.properties)
51
52
# CHECK-LABEL: testGPUToASMBin
@run
def testGPUToASMBin():
    """Serialize a gpu.module with two NVVM targets to assembly (format=isa).

    Runs ``gpu-module-to-binary{format=isa}`` and prints the resulting
    ``gpu.binary`` op. It then inspects the object produced for the first
    target, the one carrying ``flags = {fast}``: its target, format, the
    NVPTX assembly text, and its properties.
    """
    # Parse in the context/location that ``run`` already set up, rather than
    # opening a second Context here. A nested context would exit before the
    # PassManager below is created. The PassManager would then bind to a
    # different MLIRContext than the module it runs on.
    module = Module.parse(
        r"""
module attributes {gpu.container_module} {
  gpu.module @kernel_module2 [#nvvm.target<flags = {fast}>, #nvvm.target] {
    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
        %arg5: i64) attributes {gpu.kernel} {
      llvm.return
    }
  }
}
    """
    )
    pm = PassManager("any")
    pm.add("gpu-module-to-binary{format=isa}")
    pm.run(module.operation)
    # CHECK-LABEL: gpu.binary @kernel_module2
    print(module)

    # objects[0] is the object for the first listed target (flags = {fast}).
    o = gpu.ObjectAttr(module.body.operations[0].objects[0])
    # CHECK: #gpu.object<#nvvm.target<flags = {fast}>
    print(o)
    # CHECK: #nvvm.target<flags = {fast}>
    print(o.target)
    # CHECK: assembly
    print(gpu.CompilationTarget(o.format))
    # CHECK: b'//\n// Generated by LLVM NVPTX Back-End{{.*}}'
    print(o.object)
    # CHECK: {O = 2 : i32}
    print(o.properties)
86    print(o.properties)
87