// RUN: mlir-opt --split-input-file --verify-diagnostics \
// RUN:   --test-vector-reduction-to-spirv-dot-prod %s -o - | FileCheck %s

// Positive tests.
//
// Each function multiplies two extended i8 vectors and add-reduces the
// product. The whole extend/multiply/reduce chain should be replaced by a
// single SPIR-V integer dot-product op operating directly on the i8 inputs:
//   - both operands sign-extended   -> spirv.SDot
//   - both operands zero-extended   -> spirv.UDot
//   - one signed, one unsigned      -> spirv.SUDot (signed operand first)
//   - reduction with an accumulator -> the matching *AccSat variant

// Signed x signed, no accumulator, i32 result.
// CHECK-LABEL: func.func @to_sdot
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
  return %red : i32
}

// Signed x signed with an i32 accumulator.
// CHECK-LABEL: func.func @to_sdot_acc
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
  return %red : i32
}

// Signed x signed, widened to an i64 result.
// CHECK-LABEL: func.func @to_sdot_i64
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i64
// CHECK-NEXT:    return [[DOT]] : i64
func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
  %mul = arith.muli %lhs, %rhs : vector<4xi64>
  %red = vector.reduction <add>, %mul : vector<4xi64> into i64
  return %red : i64
}

// Signed x signed with an i64 accumulator.
// CHECK-LABEL: func.func @to_sdot_acc_i64
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i64
// CHECK-NEXT:    return [[DOT]] : i64
func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
  %mul = arith.muli %lhs, %rhs : vector<4xi64>
  %red = vector.reduction <add>, %mul, %acc : vector<4xi64> into i64
  return %red : i64
}

// Unsigned x unsigned, no accumulator.
// CHECK-LABEL: func.func @to_udot
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
  return %red : i32
}

// Unsigned x unsigned with an accumulator.
// CHECK-LABEL: func.func @to_udot_acc
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
  return %red : i32
}

// Signed lhs x unsigned rhs: operands keep their order.
// CHECK-LABEL: func.func @to_signed_unsigned_dot
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
  return %red : i32
}

// Signed lhs x unsigned rhs with an accumulator.
// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
  return %red : i32
}

// Unsigned lhs x signed rhs: operands are swapped so the sign-extended one
// comes first.
// CHECK-LABEL: func.func @to_unsigned_signed_dot
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul : vector<4xi32> into i32
  return %red : i32
}

// Unsigned lhs x signed rhs with an accumulator (operands swapped as above).
// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
// CHECK-NEXT:    [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : vector<4xi8> -> i32
// CHECK-NEXT:    return [[DOT]] : i32
func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %mul, %acc : vector<4xi32> into i32
  return %red : i32
}

// A 3-element input is padded to 4 elements with a zero i8 before the
// dot product.
// CHECK-LABEL: func.func @to_sdot_vector3
// CHECK-SAME:    ([[ARG0:%.+]]: vector<3xi8>, [[ARG1:%.+]]: vector<3xi8>)
// CHECK:         [[ZERO:%.+]] = spirv.Constant 0 : i8
// CHECK:         [[LHS:%.+]] = spirv.CompositeConstruct [[ARG0]], [[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
// CHECK:         [[RHS:%.+]] = spirv.CompositeConstruct [[ARG1]], [[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
// CHECK:         [[SDOT:%.+]] = spirv.SDot [[LHS]], [[RHS]] : vector<4xi8> -> i32
// CHECK:         return [[SDOT]] : i32
func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
  %rhs = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
  %mul = arith.muli %lhs, %rhs : vector<3xi32>
  %red = vector.reduction <add>, %mul : vector<3xi32> into i32
  return %red : i32
}

// -----

// Negative tests.
// Each function below is a near miss for the dot-product pattern. The
// original ops must survive and no SPIR-V op may be introduced.

// Only 2 elements (the positive tests cover 3 and 4).
// CHECK-LABEL: func.func @too_short
// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi8>, [[ARG1:%.+]]: vector<2xi8>)
// CHECK-NOT:     spirv.
// CHECK:         [[RED:%.+]] = vector.reduction
// CHECK-NEXT:    return [[RED]] : i32
func.func @too_short(%arg0: vector<2xi8>, %arg1: vector<2xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<2xi8> to vector<2xi32>
  %rhs = arith.extsi %arg1 : vector<2xi8> to vector<2xi32>
  %mul = arith.muli %lhs, %rhs : vector<2xi32>
  %red = vector.reduction <add>, %mul : vector<2xi32> into i32
  return %red : i32
}

// 6 elements, more than the 4 handled by the positive tests.
// CHECK-LABEL: func.func @too_long
// CHECK-SAME:    ([[ARG0:%.+]]: vector<6xi8>, [[ARG1:%.+]]: vector<6xi8>)
// CHECK-NOT:     spirv.
// CHECK:         [[RED:%.+]] = vector.reduction
// CHECK-NEXT:    return [[RED]] : i32
func.func @too_long(%arg0: vector<6xi8>, %arg1: vector<6xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<6xi8> to vector<6xi32>
  %rhs = arith.extsi %arg1 : vector<6xi8> to vector<6xi32>
  %mul = arith.muli %lhs, %rhs : vector<6xi32>
  %red = vector.reduction <add>, %mul : vector<6xi32> into i32
  return %red : i32
}

// Correct extend/multiply chain, but a <mul> reduction instead of <add>.
// CHECK-LABEL: func.func @wrong_reduction_kind
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NOT:     spirv.
// CHECK:         [[RED:%.+]] = vector.reduction <mul>
// CHECK-NEXT:    return [[RED]] : i32
func.func @wrong_reduction_kind(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %mul = arith.muli %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <mul>, %mul : vector<4xi32> into i32
  return %red : i32
}

// Correct <add> reduction, but its operand comes from arith.addi rather than
// arith.muli. The reduction kind is deliberately <add> so this case differs
// from the positive tests only in the arith op (a <mul> reduction would make
// it a duplicate of @wrong_reduction_kind).
// CHECK-LABEL: func.func @wrong_arith_op
// CHECK-SAME:    ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
// CHECK-NOT:     spirv.
// CHECK:         [[ADD:%.+]] = arith.addi
// CHECK:         [[RED:%.+]] = vector.reduction <add>, [[ADD]]
// CHECK-NEXT:    return [[RED]] : i32
func.func @wrong_arith_op(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %add = arith.addi %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %add : vector<4xi32> into i32
  return %red : i32
}