xref: /llvm-project/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir (revision 02307a14440ee7270ecc8490ad6e390261bb21c4)
1// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
2// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
3
// Test case: lowering of gpu.subgroup_mma_load_matrix for a 16x16 f16 "AOp"
// fragment read from a memref<32x32xf16, 3>. The source op carries the
// `transpose` unit attribute, and the expected nvvm.wmma.load below carries
// layout = #nvvm.mma_layout<col>. The converted function is expected to return
// an !llvm.struct of eight vector<2xf16> values.
4gpu.module @test_module {
5
6  // CHECK-LABEL: func @gpu_wmma_load_op() ->
7  // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
8  // CHECK32-LABEL: func @gpu_wmma_load_op() ->
9  func.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
10    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
11    %i = arith.constant 16 : index
12    %j = arith.constant 16 : index
13    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
    // Default-prefix expectations (first RUN line): the element offset is
    // computed in i64 as INX * LDM + INX. Both indices are the constant 16, so
    // a single captured INX value is expected for both terms. Only the
    // leading-dimension operand handed to nvvm.wmma.load is an i32 constant.
14    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
15    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
16    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
17    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
18    // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i64
19    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
20    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
21    // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
22    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
23    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
24    // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
25
    // 32-bit-index expectations (second RUN line, index-bitwidth=32): the same
    // sequence, but the memref descriptor fields and the offset arithmetic are
    // expected to use i32 instead of i64.
26    // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
27    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
28    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
29    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
30    // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i32
31    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
32    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
33    // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
34    // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
35    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
36    // CHECK32:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
37    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
38  }
39}
40
41// -----
42
// Test case: lowering of gpu.subgroup_mma_load_matrix for a 16x16 si8 "AOp"
// fragment read from a memref<32x32xi8, 3> with the `transpose` attribute.
// The expected nvvm.wmma.load has eltype s8 and layout col, and the converted
// function is expected to return an !llvm.struct of two i32 values.
43gpu.module @test_module {
44
45  // CHECK-LABEL: func @gpu_wmma_int8_load_op() ->
46  // CHECK-SAME: !llvm.struct<(i32, i32)>
47  // CHECK32-LABEL: func @gpu_wmma_int8_load_op() ->
48  func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) {
49    %wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3>
50    %i = arith.constant 16 : index
51    %j = arith.constant 16 : index
52    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp">
    // Default-prefix expectations: i64 offset arithmetic (INX * LDM + INX) and
    // a getelementptr whose element type is i8; the leading dimension passed
    // to nvvm.wmma.load is an i32 constant.
53    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
54    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
55    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
56    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
57    // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i64
58    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
59    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
60    // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
61    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
62    // CHECK-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
63    // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
64
    // 32-bit-index expectations (index-bitwidth=32): the same sequence with
    // i32 descriptor fields and i32 offset arithmetic.
65    // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
66    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
67    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
68    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
69    // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i32
70    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
71    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
72    // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
73    // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
74    // CHECK32-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
75    // CHECK32:  llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
76    return %0 : !gpu.mma_matrix<16x16xsi8, "AOp">
77  }
78}
79
80// -----
81
// Test case: lowering of gpu.subgroup_mma_store_matrix for a 16x16 f16 "COp"
// fragment written to a memref<32x32xf16, 3> with the `transpose` attribute.
// The fragment argument is expected to arrive as an !llvm.struct of four
// vector<2xf16> values that are unpacked and passed individually to
// nvvm.wmma.store, which is expected to carry layout col.
82gpu.module @test_module {
83
84  // CHECK-LABEL: func @gpu_wmma_store_op
85  // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
86  // CHECK32-LABEL: func @gpu_wmma_store_op
87  // CHECK32-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
88  func.func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
89    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
90    %i = arith.constant 16 : index
91    %j = arith.constant 16 : index
92    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
    // Default-prefix expectations: the last of the four matched insertvalue
    // ops is captured as MEMREF, and its field [1] is used as the base
    // pointer. The four fragment elements are extracted from the argument D
    // before the i64 address computation (INX * LDM + INX).
93    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
94    // CHECK:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
95    // CHECK:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
96    // CHECK:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
97    // CHECK:  %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
98    // CHECK:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
99    // CHECK:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
100    // CHECK:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
101    // CHECK:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
102    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
103    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
104    // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]   : i64
105    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
106    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
107    // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
108    // CHECK:  nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
109    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
110    // CHECK:  llvm.return
111
    // 32-bit-index expectations (index-bitwidth=32): the same sequence with
    // i32 descriptor fields and i32 offset arithmetic.
112    // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
113    // CHECK32:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
114    // CHECK32:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
115    // CHECK32:  %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
116    // CHECK32:  %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
117    // CHECK32:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
118    // CHECK32:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
119    // CHECK32:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
120    // CHECK32:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
121    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
122    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
123    // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]   : i32
124    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
125    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
126    // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
127    // CHECK32:  nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
128    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
129    // CHECK32:  llvm.return
130    return
131  }
132}
133
134// -----
135
// Test case: lowering of gpu.subgroup_mma_compute on 16x16 f16 operands.
// The A and B arguments are each expected as an !llvm.struct of eight
// vector<2xf16> values and C as a struct of four; every element is extracted
// and passed individually to nvvm.wmma.mma. The source op sets `a_transpose`
// only, and the directives expect layoutA = col with layoutB = row.
// Only the default prefix is verified in this module.
136gpu.module @test_module {
137
138  // CHECK-LABEL: func @gpu_wmma_mma_op
139  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
140  func.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
141    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
    // Operand unpacking is expected in the order A[0..7], B[0..7], C[0..3],
    // which is also the operand order of the nvvm.wmma.mma below.
142    // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
143    // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
144    // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
145    // CHECK:  %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
146    // CHECK:  %[[A5:.*]] = llvm.extractvalue %[[A]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
147    // CHECK:  %[[A6:.*]] = llvm.extractvalue %[[A]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
148    // CHECK:  %[[A7:.*]] = llvm.extractvalue %[[A]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
149    // CHECK:  %[[A8:.*]] = llvm.extractvalue %[[A]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
150    // CHECK:  %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
151    // CHECK:  %[[B2:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
152    // CHECK:  %[[B3:.*]] = llvm.extractvalue %[[B]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
153    // CHECK:  %[[B4:.*]] = llvm.extractvalue %[[B]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
154    // CHECK:  %[[B5:.*]] = llvm.extractvalue %[[B]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
155    // CHECK:  %[[B6:.*]] = llvm.extractvalue %[[B]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
156    // CHECK:  %[[B7:.*]] = llvm.extractvalue %[[B]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
157    // CHECK:  %[[B8:.*]] = llvm.extractvalue %[[B]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
158    // CHECK:  %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
159    // CHECK:  %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
160    // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
161    // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
162    // CHECK:  %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]]
163    // CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
164    // CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
165    // CHECK:  llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
166    return %D : !gpu.mma_matrix<16x16xf16, "COp">
167  }
168}
169
170// -----
171
// Test case: lowering of gpu.subgroup_mma_compute with integer operands:
// A is 32x16 si8, B is 16x8 si8 and C is 32x8 i32. The directives expect the
// converted arguments to be !llvm.struct values of four, one and eight i32
// fields respectively, and an nvvm.wmma.mma with m = 32, n = 8, k = 16,
// eltypeA = s8 and eltypeB = s32. The source op sets `a_transpose` only; the
// expected layouts are layoutA = col and layoutB = row.
// Only the default prefix is verified in this module.
172gpu.module @test_module {
173
174  // CHECK-LABEL: func @gpu_wmma_mma_int8_op
175  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>)
176  func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) {
177    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
    // Operand unpacking is expected in the order A[0..3], B[0], C[0..7],
    // matching the operand order of the nvvm.wmma.mma below.
178    // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)>
179    // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)>
180    // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(i32, i32, i32, i32)>
181    // CHECK:  %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(i32, i32, i32, i32)>
182    // CHECK:  %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(i32)>
183    // CHECK:  %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
184    // CHECK:  %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
185    // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
186    // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
187    // CHECK:  %[[C5:.*]] = llvm.extractvalue %[[C]][4] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
188    // CHECK:  %[[C6:.*]] = llvm.extractvalue %[[C]][5] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
189    // CHECK:  %[[C7:.*]] = llvm.extractvalue %[[C]][6] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
190    // CHECK:  %[[C8:.*]] = llvm.extractvalue %[[C]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
191    // CHECK:  %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]]
192    // CHECK-SAME: {eltypeA = #nvvm.mma_type<s8>, eltypeB = #nvvm.mma_type<s32>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 32 : i32, n = 8 : i32} : (
193    // CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
194    // CHECK:  llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
195    return %D : !gpu.mma_matrix<32x8xi32, "COp">
196  }
197}
198
199// -----
200
// Test case: a "COp" accumulator fragment carried through an unstructured
// loop as a block argument. The function loads C once, branches to ^bb1 with
// the fragment as a block argument, loads A and B and runs
// gpu.subgroup_mma_compute in ^bb2, and stores the accumulator in ^bb3.
// The directives expect the converted branches and the ^bb1 block argument to
// use an i64 value plus an !llvm.struct of four vector<2xf16> values. None of
// the source ops carries a transpose attribute and the memrefs have no memory
// space; the expected ops use layout row and plain !llvm.ptr.
// Only the default prefix is verified in this module.
201gpu.module @test_module {
202
203// CHECK-LABEL: func @gpu_wmma_mma_loop_op
204//       CHECK:   %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<c>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
205//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
206//       CHECK:  ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>):  // 2 preds: ^bb0, ^bb2
207//       CHECK:   llvm.cond_br %{{.*}}, ^bb2, ^bb3
208//       CHECK:  ^bb2:  // pred: ^bb1
209//       CHECK:   %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
210//       CHECK:   %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<b>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
211//       CHECK:   %[[A0:.+]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
212//       CHECK:   %[[A1:.+]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
213//       CHECK:   %[[A2:.+]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
214//       CHECK:   %[[A3:.+]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
215//       CHECK:   %[[A4:.+]] = llvm.extractvalue %[[A]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
216//       CHECK:   %[[A5:.+]] = llvm.extractvalue %[[A]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
217//       CHECK:   %[[A6:.+]] = llvm.extractvalue %[[A]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
218//       CHECK:   %[[A7:.+]] = llvm.extractvalue %[[A]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
219//       CHECK:   %[[B0:.+]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
220//       CHECK:   %[[B1:.+]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
221//       CHECK:   %[[B2:.+]] = llvm.extractvalue %[[B]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
222//       CHECK:   %[[B3:.+]] = llvm.extractvalue %[[B]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
223//       CHECK:   %[[B4:.+]] = llvm.extractvalue %[[B]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
224//       CHECK:   %[[B5:.+]] = llvm.extractvalue %[[B]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
225//       CHECK:   %[[B6:.+]] = llvm.extractvalue %[[B]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
226//       CHECK:   %[[B7:.+]] = llvm.extractvalue %[[B]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
227//       CHECK:   %[[ACC0:.+]] = llvm.extractvalue %[[ACC]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
228//       CHECK:   %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
229//       CHECK:   %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
230//       CHECK:   %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
231//       CHECK:   %[[ACC_MUL:.+]] = nvvm.wmma.mma %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (vector<2xf16>, {{.*}} -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
232//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
233//       CHECK:  ^bb3:  // pred: ^bb1
234//       CHECK:   %[[E0:.+]] = llvm.extractvalue %[[ACC]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
235//       CHECK:   %[[E1:.+]] = llvm.extractvalue %[[ACC]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
236//       CHECK:   %[[E2:.+]] = llvm.extractvalue %[[ACC]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
237//       CHECK:   %[[E3:.+]] = llvm.extractvalue %[[ACC]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
238//       CHECK:   nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
239
240  func.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
241      %c0 = arith.constant 0 : index
242      %c128 = arith.constant 128 : index
243      %c32 = arith.constant 32 : index
      // Initial accumulator: the C fragment loaded from %arg2 at [0, 0].
244      %0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
245      cf.br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
    // Loop header: %1 is the induction value, %2 the running accumulator.
246    ^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">):  // 2 preds: ^bb0, ^bb2
247      %3 = arith.cmpi slt, %1, %c128 : index
248      cf.cond_br %3, ^bb2, ^bb3
    // Loop body: load A at [0, %1] and B at [%1, 0], accumulate into %2, then
    // advance the induction value by %c32.
249    ^bb2:  // pred: ^bb1
250      %4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
251      %5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
252      %6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
253      %7 = arith.addi %1, %c32 : index
254      cf.br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
    // Loop exit: write the accumulator back to %arg2 at [0, 0].
255    ^bb3:  // pred: ^bb1
256      gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
257      return
258    }
259}
260
261
262// -----
263
// Test case: lowering of gpu.subgroup_mma_constant_matrix for a 16x16 f16
// "COp" fragment built from the scalar constant 1.0. The directives expect
// the scalar to be inserted into both lanes of a vector<2xf16>, and that
// vector to be inserted into all four fields of the result !llvm.struct,
// which is then returned.
// Only the default prefix is verified in this module.
264gpu.module @test_module {
265
266// CHECK-LABEL: func @gpu_wmma_constant_op
267//       CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16
268//       CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<2xf16>
269//       CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
270//       CHECK: %[[V1:.+]] = llvm.insertelement %[[CST]], %[[V0]][%[[C0]] : i32] : vector<2xf16>
271//       CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
272//       CHECK: %[[V2:.+]] = llvm.insertelement %[[CST]], %[[V1]][%[[C1]] : i32] : vector<2xf16>
273//       CHECK: %[[M0:.+]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
274//       CHECK: %[[M1:.+]] = llvm.insertvalue %[[V2]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
275//       CHECK: %[[M2:.+]] = llvm.insertvalue %[[V2]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
276//       CHECK: %[[M3:.+]] = llvm.insertvalue %[[V2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
277//       CHECK: %[[M4:.+]] = llvm.insertvalue %[[V2]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
278//       CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
279  func.func @gpu_wmma_constant_op()  ->(!gpu.mma_matrix<16x16xf16, "COp">) {
280    %cst = arith.constant 1.0 : f16
281    %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
282    return %C : !gpu.mma_matrix<16x16xf16, "COp">
283  }
284}
285
286// -----
287
// Test case: lowering of gpu.subgroup_mma_elementwise on 16x16 f16 "COp"
// fragments. The function applies `addf` to (%A, %B) and then `maxf` to
// (%C, %B). The directives expect each op to be expanded field by field over
// an !llvm.struct of four vector<2xf16> values: extract field i from both
// operands, combine, and insert into field i of a fresh undef struct.
// Only the default prefix is verified in this module.
288gpu.module @test_module {
289
290// CHECK-LABEL: func @gpu_wmma_elementwise
// First group: expectations for the `addf` op, one llvm.fadd per field.
291//       CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
292//       CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
293//       CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
294//       CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]]  : vector<2xf16>
295//       CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
296//       CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
297//       CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
298//       CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]]  : vector<2xf16>
299//       CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
300//       CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
301//       CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
302//       CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]]  : vector<2xf16>
303//       CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
304//       CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
305//       CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
306//       CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]]  : vector<2xf16>
307//       CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
308
// Second group: expectations for the `maxf` op. Per field, the expected
// sequence is fcmp "ogt" + select, followed by fcmp "uno" and a second select
// that picks the constant captured as NAN when the "uno" compare is true.
// The pattern variables from the first group are redefined here.
309//       CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
310//       CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
311//       CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
312//       CHECK: %[[CMP0:.*]] = llvm.fcmp "ogt" %[[A0]], %[[B0]] : vector<2xf16>
313//       CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[A0]], %[[B0]] : vector<2xi1>, vector<2xf16>
314//       CHECK: %[[CMP1:.*]] = llvm.fcmp "uno" %[[A0]], %[[B0]] : vector<2xf16>
315//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
316//       CHECK: %[[C0:.*]] = llvm.select %[[CMP1]], %[[NAN]], %[[SEL0]] : vector<2xi1>, vector<2xf16>
317//       CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
318//       CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
319//       CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
320//       CHECK: %[[CMP2:.*]] = llvm.fcmp "ogt" %[[A1]], %[[B1]] : vector<2xf16>
321//       CHECK: %[[SEL1:.*]] = llvm.select %[[CMP2]], %[[A1]], %[[B1]] : vector<2xi1>, vector<2xf16>
322//       CHECK: %[[CMP3:.*]] = llvm.fcmp "uno" %[[A1]], %[[B1]] : vector<2xf16>
323//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
324//       CHECK: %[[C1:.*]] = llvm.select %[[CMP3]], %[[NAN]], %[[SEL1]] : vector<2xi1>, vector<2xf16>
325//       CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
326//       CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
327//       CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
328//       CHECK: %[[CMP4:.*]] = llvm.fcmp "ogt" %[[A2]], %[[B2]] : vector<2xf16>
329//       CHECK: %[[SEL2:.*]] = llvm.select %[[CMP4]], %[[A2]], %[[B2]] : vector<2xi1>, vector<2xf16>
330//       CHECK: %[[CMP5:.*]] = llvm.fcmp "uno" %[[A2]], %[[B2]] : vector<2xf16>
331//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
332//       CHECK: %[[C2:.*]] = llvm.select %[[CMP5]], %[[NAN]], %[[SEL2]] : vector<2xi1>, vector<2xf16>
333//       CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
334//       CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
335//       CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
336//       CHECK: %[[CMP6:.*]] = llvm.fcmp "ogt" %[[A3]], %[[B3]] : vector<2xf16>
337//       CHECK: %[[SEL3:.*]] = llvm.select %[[CMP6]], %[[A3]], %[[B3]] : vector<2xi1>, vector<2xf16>
338//       CHECK: %[[CMP7:.*]] = llvm.fcmp "uno" %[[A3]], %[[B3]] : vector<2xf16>
339//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
340//       CHECK: %[[C3:.*]] = llvm.select %[[CMP7]], %[[NAN]], %[[SEL3]] : vector<2xi1>, vector<2xf16>
341//       CHECK: %[[M5:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
342
// The struct produced by the second group is the value expected to be returned.
343//       CHECK: llvm.return %[[M5]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
344  func.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">)  ->(!gpu.mma_matrix<16x16xf16, "COp">) {
345    %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
346    %D = gpu.subgroup_mma_elementwise maxf %C, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
347    return %D : !gpu.mma_matrix<16x16xf16, "COp">
348  }
349}
350