// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s

// Tests for lowering the gpu.subgroup_mma_* ops to nvvm.wmma.* ops.
// The file is processed twice: once with the default index width (i64 address
// arithmetic is expected) and once with index-bitwidth=32 (i32 address
// arithmetic is expected). Only the load and store tests carry expectations
// for the 32-bit run.

// Load of a 16x16 f16 "AOp" fragment with the transpose attribute from a
// memref in address space 3. Expects an nvvm.wmma.load with frag a and col
// layout, fed by base + (i * leadDimension + j) address arithmetic.
gpu.module @test_module {

  // CHECK-LABEL: func @gpu_wmma_load_op() ->
  // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
  // CHECK32-LABEL: func @gpu_wmma_load_op() ->
  func.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
    %i = arith.constant 16 : index
    %j = arith.constant 16 : index
    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
    // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

    // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
    // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
  }
}

// -----

// Same load as above but for a 16x16 si8 "AOp" fragment: the converted
// fragment type is a struct of two i32 and the element type is s8.
gpu.module @test_module {

  // CHECK-LABEL: func @gpu_wmma_int8_load_op() ->
  // CHECK-SAME: !llvm.struct<(i32, i32)>
  // CHECK32-LABEL: func @gpu_wmma_int8_load_op() ->
  func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) {
    %wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3>
    %i = arith.constant 16 : index
    %j = arith.constant 16 : index
    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp">
    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
    // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
    // CHECK-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
    // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>

    // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
    // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
    // CHECK32-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
    // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
    return %0 : !gpu.mma_matrix<16x16xsi8, "AOp">
  }
}

// -----

// Store of a 16x16 f16 "COp" fragment with the transpose attribute into a
// memref in address space 3. Expects the four vector<2xf16> members to be
// extracted from the converted struct and passed to an nvvm.wmma.store with
// col layout.
gpu.module @test_module {

  // CHECK-LABEL: func @gpu_wmma_store_op
  // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
  // CHECK32-LABEL: func @gpu_wmma_store_op
  // CHECK32-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
  func.func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
    %i = arith.constant 16 : index
    %j = arith.constant 16 : index
    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
    // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
    // CHECK: llvm.return

    // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
    // CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)>
    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
    // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
    // CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
    // CHECK32: llvm.return
    return
  }
}

// -----

// gpu.subgroup_mma_compute on 16x16 f16 operands with a_transpose set.
// Expects every member of the A, B and C structs to be extracted and passed,
// in order, to a single nvvm.wmma.mma with layoutA = col and layoutB = row.
gpu.module @test_module {

  // CHECK-LABEL: func @gpu_wmma_mma_op
  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
  func.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
    // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A5:.*]] = llvm.extractvalue %[[A]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A6:.*]] = llvm.extractvalue %[[A]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A7:.*]] = llvm.extractvalue %[[A]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[A8:.*]] = llvm.extractvalue %[[A]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B2:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B3:.*]] = llvm.extractvalue %[[B]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B4:.*]] = llvm.extractvalue %[[B]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B5:.*]] = llvm.extractvalue %[[B]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B6:.*]] = llvm.extractvalue %[[B]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B7:.*]] = llvm.extractvalue %[[B]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[B8:.*]] = llvm.extractvalue %[[B]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]]
    // CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
    // CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    return %D : !gpu.mma_matrix<16x16xf16, "COp">
  }
}

// -----

// gpu.subgroup_mma_compute on si8 inputs (32x16 A, 16x8 B) accumulating into
// a 32x8 i32 C, with a_transpose set. The converted operands are structs of
// four, one and eight i32 respectively.
gpu.module @test_module {

  // CHECK-LABEL: func @gpu_wmma_mma_int8_op
  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>)
  func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) {
    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
    // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)>
    // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)>
    // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(i32, i32, i32, i32)>
    // CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(i32, i32, i32, i32)>
    // CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(i32)>
    // CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C5:.*]] = llvm.extractvalue %[[C]][4] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C6:.*]] = llvm.extractvalue %[[C]][5] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C7:.*]] = llvm.extractvalue %[[C]][6] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[C8:.*]] = llvm.extractvalue %[[C]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]]
    // CHECK-SAME: {eltypeA = #nvvm.mma_type<s8>, eltypeB = #nvvm.mma_type<s32>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 32 : i32, n = 8 : i32} : (
    // CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    // CHECK: llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
    return %D : !gpu.mma_matrix<32x8xi32, "COp">
  }
}

// -----

// An MMA accumulator carried through an unstructured loop as a block argument.
// Expects the "COp" value passed via cf.br to appear as the converted struct
// type on llvm.br and on the ^bb1 block argument, and the load / compute /
// store ops (no transpose attributes here) to lower with row layouts.
gpu.module @test_module {

// CHECK-LABEL: func @gpu_wmma_mma_loop_op
// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<c>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
// CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK: ^bb2: // pred: ^bb1
// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<b>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.+]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A4:.+]] = llvm.extractvalue %[[A]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A5:.+]] = llvm.extractvalue %[[A]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A6:.+]] = llvm.extractvalue %[[A]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A7:.+]] = llvm.extractvalue %[[A]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B0:.+]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B1:.+]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B2:.+]] = llvm.extractvalue %[[B]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B3:.+]] = llvm.extractvalue %[[B]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B4:.+]] = llvm.extractvalue %[[B]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B5:.+]] = llvm.extractvalue %[[B]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B6:.+]] = llvm.extractvalue %[[B]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B7:.+]] = llvm.extractvalue %[[B]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC0:.+]] = llvm.extractvalue %[[ACC]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.mma %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb3: // pred: ^bb1
// CHECK: %[[E0:.+]] = llvm.extractvalue %[[ACC]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>

  func.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c32 = arith.constant 32 : index
    %0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
    cf.br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
  ^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">):  // 2 preds: ^bb0, ^bb2
    %3 = arith.cmpi slt, %1, %c128 : index
    cf.cond_br %3, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
    %5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
    %6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
    %7 = arith.addi %1, %c32 : index
    cf.br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
  ^bb3:  // pred: ^bb1
    gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
    return
  }
}


// -----

// gpu.subgroup_mma_constant_matrix from an f16 scalar. Expects the scalar to
// be inserted into both lanes of a vector<2xf16>, and that vector to be
// inserted into each of the four members of the converted "COp" struct.
gpu.module @test_module {

// CHECK-LABEL: func @gpu_wmma_constant_op
// CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<2xf16>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CST]], %[[V0]][%[[C0]] : i32] : vector<2xf16>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CST]], %[[V1]][%[[C1]] : i32] : vector<2xf16>
// CHECK: %[[M0:.+]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[M1:.+]] = llvm.insertvalue %[[V2]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[M2:.+]] = llvm.insertvalue %[[V2]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[M3:.+]] = llvm.insertvalue %[[V2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[M4:.+]] = llvm.insertvalue %[[V2]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
  func.func @gpu_wmma_constant_op() ->(!gpu.mma_matrix<16x16xf16, "COp">) {
    %cst = arith.constant 1.0 : f16
    %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
    return %C : !gpu.mma_matrix<16x16xf16, "COp">
  }
}

// -----

// gpu.subgroup_mma_elementwise addf followed by maxf on "COp" f16 fragments.
// addf is expected as one llvm.fadd per struct member. maxf is expected per
// member as fcmp "ogt" + select, then fcmp "uno" + select of a 0x7E00 constant.
gpu.module @test_module {

// CHECK-LABEL: func @gpu_wmma_elementwise
// CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]] : vector<2xf16>
// CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]] : vector<2xf16>
// CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]] : vector<2xf16>
// CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]] : vector<2xf16>
// CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

// CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[CMP0:.*]] = llvm.fcmp "ogt" %[[A0]], %[[B0]] : vector<2xf16>
// CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[A0]], %[[B0]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[CMP1:.*]] = llvm.fcmp "uno" %[[A0]], %[[B0]] : vector<2xf16>
// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
// CHECK: %[[C0:.*]] = llvm.select %[[CMP1]], %[[NAN]], %[[SEL0]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[CMP2:.*]] = llvm.fcmp "ogt" %[[A1]], %[[B1]] : vector<2xf16>
// CHECK: %[[SEL1:.*]] = llvm.select %[[CMP2]], %[[A1]], %[[B1]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[CMP3:.*]] = llvm.fcmp "uno" %[[A1]], %[[B1]] : vector<2xf16>
// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
// CHECK: %[[C1:.*]] = llvm.select %[[CMP3]], %[[NAN]], %[[SEL1]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[CMP4:.*]] = llvm.fcmp "ogt" %[[A2]], %[[B2]] : vector<2xf16>
// CHECK: %[[SEL2:.*]] = llvm.select %[[CMP4]], %[[A2]], %[[B2]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[CMP5:.*]] = llvm.fcmp "uno" %[[A2]], %[[B2]] : vector<2xf16>
// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
// CHECK: %[[C2:.*]] = llvm.select %[[CMP5]], %[[NAN]], %[[SEL2]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[CMP6:.*]] = llvm.fcmp "ogt" %[[A3]], %[[B3]] : vector<2xf16>
// CHECK: %[[SEL3:.*]] = llvm.select %[[CMP6]], %[[A3]], %[[B3]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[CMP7:.*]] = llvm.fcmp "uno" %[[A3]], %[[B3]] : vector<2xf16>
// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
// CHECK: %[[C3:.*]] = llvm.select %[[CMP7]], %[[NAN]], %[[SEL3]] : vector<2xi1>, vector<2xf16>
// CHECK: %[[M5:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

// CHECK: llvm.return %[[M5]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
  func.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) {
    %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
    %D = gpu.subgroup_mma_elementwise maxf %C, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
    return %D : !gpu.mma_matrix<16x16xf16, "COp">
  }
}