# REQUIRES: host-supports-nvptx
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.gpu as gpu

# NOTE(review): this module is not referenced by name below; presumably it is
# imported so the GPU passes are registered with the pass manager - confirm.
import mlir.dialects.gpu.passes
from mlir.passmanager import *


def run(f):
    """Run test function `f` immediately and return it unchanged.

    Prints a banner line containing the function name (the label directives
    below match on it), then calls `f` with a fresh `Context` and an unknown
    `Location` installed as the defaults. Used as a decorator, so every test
    executes at import time.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        f()
    return f


# CHECK-LABEL: testGPUToLLVMBin
@run
def testGPUToLLVMBin():
    """Serialize a GPU module with `format=llvm` and print the object fields.

    Parses a container module holding one `gpu.module` with an NVVM `sm_70`
    target, runs `gpu-module-to-binary{format=llvm}` on it, and prints the
    resulting module followed by each accessor of the first object attribute.
    """
    # Everything that touches the parsed module stays inside this block so the
    # pass manager is created in the same context that owns the module.
    with Context():
        module = Module.parse(
            r"""
module attributes {gpu.container_module} {
  gpu.module @kernel_module1 [#nvvm.target<chip = "sm_70">] {
    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
        %arg5: i64) attributes {gpu.kernel} {
      llvm.return
    }
  }
}
    """
        )
        pm = PassManager("any")
        pm.add("gpu-module-to-binary{format=llvm}")
        pm.run(module.operation)
        # CHECK-LABEL: gpu.binary @kernel_module1
        print(module)

        # First op of the module body, first entry of its `objects` list.
        o = gpu.ObjectAttr(module.body.operations[0].objects[0])
        # CHECK: #gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">
        print(o)
        # CHECK: #nvvm.target<chip = "sm_70">
        print(o.target)
        # CHECK: offload
        print(gpu.CompilationTarget(o.format))
        # CHECK: b'{{.*}}'
        print(o.object)
        # CHECK: None
        print(o.properties)


# CHECK-LABEL: testGPUToASMBin
@run
def testGPUToASMBin():
    """Serialize a GPU module with `format=isa` and print the object fields.

    Parses a container module holding one `gpu.module` with two NVVM targets
    (one carrying `flags = {fast}`, one default), runs
    `gpu-module-to-binary{format=isa}` on it, and prints the resulting module
    followed by each accessor of the first object attribute.
    """
    # Everything that touches the parsed module stays inside this block so the
    # pass manager is created in the same context that owns the module.
    with Context():
        module = Module.parse(
            r"""
module attributes {gpu.container_module} {
  gpu.module @kernel_module2 [#nvvm.target<flags = {fast}>, #nvvm.target] {
    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
        %arg5: i64) attributes {gpu.kernel} {
      llvm.return
    }
  }
}
    """
        )
        pm = PassManager("any")
        pm.add("gpu-module-to-binary{format=isa}")
        pm.run(module.operation)
        # CHECK-LABEL:gpu.binary @kernel_module2
        print(module)

        # First op of the module body, first entry of its `objects` list.
        o = gpu.ObjectAttr(module.body.operations[0].objects[0])
        # CHECK: #gpu.object<#nvvm.target<flags = {fast}>
        print(o)
        # CHECK: #nvvm.target<flags = {fast}>
        print(o.target)
        # CHECK: assembly
        print(gpu.CompilationTarget(o.format))
        # CHECK: b'//\n// Generated by LLVM NVPTX Back-End{{.*}}'
        print(o.object)
        # CHECK: {O = 2 : i32}
        print(o.properties)