// RUN: mlir-opt --convert-gpu-to-spirv --cse \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s

// Tests lowering of the GPU dialect WMMA ops (gpu.subgroup_mma_*) to the
// SPIR-V KHR cooperative matrix ops (spirv.KHR.CooperativeMatrix{Load,Store,
// MulAdd} and elementwise arithmetic on !spirv.coopmatrix values). The target
// environment enables the CooperativeMatrixKHR capability and the
// SPV_KHR_cooperative_matrix extension required by the lowering.

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
    [Shader, CooperativeMatrixKHR, Float16],
    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
    #spirv.resource_limits<>>} {

  gpu.module @kernels {
    // gpu.subgroup_mma_load_matrix lowers to spirv.KHR.CooperativeMatrixLoad;
    // the `transpose` unit attribute selects <ColumnMajor> instead of
    // <RowMajor>, and `leadDimension = 32` becomes the i32 stride constant
    // (shared by both loads thanks to --cse).
    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 16 : index
      %j = arith.constant 16 : index
      // CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
      // CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">

      // CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: spirv.Return
      gpu.return
    }

    // gpu.subgroup_mma_store_matrix lowers to spirv.KHR.CooperativeMatrixStore,
    // again mapping `transpose` to <ColumnMajor> and reusing the CSE'd stride.
    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
                                %arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 16 : index
      %j = arith.constant 16 : index
      // CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>

      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // gpu.subgroup_mma_compute lowers to spirv.KHR.CooperativeMatrixMulAdd;
    // the "AOp"/"BOp"/"COp" operands map to the MatrixA/MatrixB/MatrixAcc
    // coopmatrix uses.
    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
                              %B: !gpu.mma_matrix<16x16xf16, "BOp">,
                              %C: !gpu.mma_matrix<16x16xf16, "COp">,
                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      // CHECK: %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
      // CHECK-SAME: -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
                                                 !gpu.mma_matrix<16x16xf16, "BOp">
                                                 -> !gpu.mma_matrix<16x16xf16, "COp">

      %i = arith.constant 0 : index
      // The store keeps the MulAdd result live so it is not CSE'd/DCE'd away.
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // gpu.subgroup_mma_constant_matrix lowers to a splat built with
    // spirv.CompositeConstruct from the scalar constant.
    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
    gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      // CHECK: %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
      %cst = arith.constant 1.0 : f16
      // CHECK: %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
      // CHECK-SAME: (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">

      %i = arith.constant 0 : index
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // gpu.subgroup_mma_elementwise ops lower to elementwise SPIR-V arithmetic
    // directly on coopmatrix values: addf -> FAdd, negatef -> FNegate,
    // divf -> FDiv, and extf (f16 -> f32) -> FConvert.
    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
                                              %B: !gpu.mma_matrix<16x16xf16, "COp">,
                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %C = gpu.subgroup_mma_elementwise addf %A, %B :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %D = gpu.subgroup_mma_elementwise negatef %C :
        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %E = gpu.subgroup_mma_elementwise divf %D, %A :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
      %F = gpu.subgroup_mma_elementwise extf %E :
        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">

      %i = arith.constant 0 : index
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // mulf against a constant-splat matrix folds to spirv.MatrixTimesScalar,
    // regardless of whether the splat is the LHS or the RHS operand.
    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    // CHECK-SAME: %[[S:.+]]: f16
    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
        %A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
        %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 0 : index

      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
      %C = gpu.subgroup_mma_elementwise mulf %A, %B :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>

      // Splat as the LHS operand (%B, %C) also folds to MatrixTimesScalar.
      // CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
      %D = gpu.subgroup_mma_elementwise mulf %B, %C :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // addf with a constant-splat operand keeps the splat (CompositeConstruct)
    // and lowers to a plain FAdd — no scalar special case, unlike mulf above.
    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
    // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    // CHECK-SAME: %[[S:.+]]: f16
    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
        %A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
        %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 0 : index

      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %C = gpu.subgroup_mma_elementwise addf %A, %B :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">

      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }
  }
}