// Source: llvm-project/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (revision 6e90f13cc9bc9dbc5c2c248d95c6e18a5fb021b4)
// RUN: mlir-opt --convert-gpu-to-spirv --cse \
// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s

// Tests conversion of gpu WMMA ops (subgroup_mma_load/store/compute/
// constant/elementwise) to SPIR-V KHR cooperative matrix ops.
module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
    [Shader, CooperativeMatrixKHR, Float16],
    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
    #spirv.resource_limits<>>} {

  gpu.module @kernels {
    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 16 : index
      %j = arith.constant 16 : index
      // CHECK:      %[[STRIDE:.+]] = spirv.Constant 32 : i32
      // CHECK:      spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
      // CHECK-SAME:   !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">

      // CHECK:      spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
      // CHECK-SAME:   !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: spirv.Return
      gpu.return
    }

    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
                                %arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 16 : index
      %j = arith.constant 16 : index
      // CHECK:      %[[STRIDE:.+]] = spirv.Constant 32 : i32
      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
      // CHECK-SAME:  !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>

      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
      // CHECK-SAME:  !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
                              %B: !gpu.mma_matrix<16x16xf16, "BOp">,
                              %C: !gpu.mma_matrix<16x16xf16, "COp">,
                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      // CHECK:      %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
      // CHECK-SAME:   -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
                                                 !gpu.mma_matrix<16x16xf16, "BOp">
                                                 -> !gpu.mma_matrix<16x16xf16, "COp">

      %i = arith.constant 0 : index
      // CHECK:      spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
    gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      // CHECK:       %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
      %cst = arith.constant 1.0 : f16
      // CHECK:       %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
      // CHECK-SAME:   (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">

      %i = arith.constant 0 : index
      // CHECK:      spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
                                              %B: !gpu.mma_matrix<16x16xf16, "COp">,
                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %C = gpu.subgroup_mma_elementwise addf %A, %B :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK:  {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %D = gpu.subgroup_mma_elementwise negatef %C :
        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %E = gpu.subgroup_mma_elementwise divf %D, %A :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK:  {{%.*}} = spirv.FConvert {{%.*}} :
      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
      %F = gpu.subgroup_mma_elementwise extf %E :
        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">

      %i = arith.constant 0 : index
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    // CHECK-SAME:    %[[S:.+]]: f16
    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
      %A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
      %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 0 : index

      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
      %C = gpu.subgroup_mma_elementwise mulf %A, %B :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>

      // CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
      %D = gpu.subgroup_mma_elementwise mulf %B, %C :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }

    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    // CHECK-SAME:    %[[S:.+]]: f16
    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
      %A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
      %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
      %i = arith.constant 0 : index

      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
      // CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
      %C = gpu.subgroup_mma_elementwise addf %A, %B :
        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">

      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
      // CHECK: spirv.Return
      gpu.return
    }
  }
}