xref: /llvm-project/mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir (revision 1b70587ca13bc1d372ddd3928818b7d774e21595)
// RUN: mlir-opt --split-input-file --verify-diagnostics \
// RUN:   --test-vector-reduction-to-spirv-dot-prod %s -o - | FileCheck %s
3
// Positive tests: reductions that are expected to be rewritten into SPIR-V
// integer dot product ops.
5
// Both i8 operands are sign-extended, multiplied element-wise and summed with
// an <add> reduction; the expected output is a single spirv.SDot on the
// original i8 vectors.
// CHECK-LABEL: func.func @to_sdot
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDot [[LHS]], [[RHS]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a, %b : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}
17
// Same signed pattern as @to_sdot, but the reduction carries an accumulator
// operand; the expected output is spirv.SDotAccSat taking that accumulator.
// CHECK-LABEL: func.func @to_sdot_acc
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDotAccSat [[LHS]], [[RHS]], [[INIT]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %a = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a, %b : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}
29
// Signed pattern with the operands extended to i64 and an i64 result; the
// expected spirv.SDot still consumes the i8 vectors and produces i64.
// CHECK-LABEL: func.func @to_sdot_i64
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDot [[LHS]], [[RHS]] : vector<4xi8> -> i64
//  CHECK-NEXT:   return [[RES]] : i64
func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
  %a = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
  %b = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
  %prod = arith.muli %a, %b : vector<4xi64>
  %sum = vector.reduction <add>, %prod : vector<4xi64> into i64
  return %sum : i64
}
41
// i64 variant of the signed accumulate case: the reduction's i64 accumulator
// is expected to become the third operand of spirv.SDotAccSat.
// CHECK-LABEL: func.func @to_sdot_acc_i64
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>, [[INIT:%.+]]: i64)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDotAccSat [[LHS]], [[RHS]], [[INIT]] : vector<4xi8> -> i64
//  CHECK-NEXT:   return [[RES]] : i64
func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
  %a = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
  %b = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
  %prod = arith.muli %a, %b : vector<4xi64>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi64> into i64
  return %sum : i64
}
53
// Both i8 operands are zero-extended before the multiply and <add> reduction;
// the expected output is the unsigned dot product spirv.UDot.
// CHECK-LABEL: func.func @to_udot
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.UDot [[LHS]], [[RHS]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %b = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a, %b : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}
65
// Unsigned pattern with an accumulator on the reduction; the expected output
// is spirv.UDotAccSat with the accumulator as its third operand.
// CHECK-LABEL: func.func @to_udot_acc
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.UDotAccSat [[LHS]], [[RHS]], [[INIT]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %a = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %b = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a, %b : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}
77
// Mixed signedness: %arg0 is sign-extended and %arg1 is zero-extended. The
// expected output is spirv.SUDot with the operands in their original order.
// CHECK-LABEL: func.func @to_signed_unsigned_dot
//  CHECK-SAME:   ([[SIGNED:%.+]]: vector<4xi8>, [[UNSIGNED:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDot [[SIGNED]], [[UNSIGNED]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %s = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %u = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %s, %u : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}
89
// Mixed signedness (%arg0 sign-extended, %arg1 zero-extended) with an
// accumulator; the expected output is spirv.SUDotAccSat with the operands in
// their original order followed by the accumulator.
// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
//  CHECK-SAME:   ([[SIGNED:%.+]]: vector<4xi8>, [[UNSIGNED:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDotAccSat [[SIGNED]], [[UNSIGNED]], [[INIT]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %s = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %u = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %s, %u : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}
101
// Mixed signedness with the roles flipped: %arg0 is zero-extended and %arg1 is
// sign-extended. The expected spirv.SUDot lists the sign-extended operand
// (%arg1) first, i.e. the operands appear swapped relative to the multiply.
// CHECK-LABEL: func.func @to_unsigned_signed_dot
//  CHECK-SAME:   ([[UNSIGNED:%.+]]: vector<4xi8>, [[SIGNED:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDot [[SIGNED]], [[UNSIGNED]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %u = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %s = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %u, %s : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}
113
// Accumulating form of the flipped mixed-signedness case: %arg0 is
// zero-extended, %arg1 is sign-extended. The expected spirv.SUDotAccSat lists
// the sign-extended operand (%arg1) first, then %arg0, then the accumulator.
// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
//  CHECK-SAME:   ([[UNSIGNED:%.+]]: vector<4xi8>, [[SIGNED:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDotAccSat [[SIGNED]], [[UNSIGNED]], [[INIT]] : vector<4xi8> -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %u = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %s = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %u, %s : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}
125
// 3-element inputs: each i8 vector is expected to be widened to 4 lanes by
// appending a zero i8 via spirv.CompositeConstruct, after which the padded
// vectors feed spirv.SDot.
// CHECK-LABEL: func.func @to_sdot_vector3
//  CHECK-SAME: ([[IN0:%.+]]: vector<3xi8>, [[IN1:%.+]]: vector<3xi8>)
//       CHECK:   [[PAD:%.+]] = spirv.Constant 0 : i8
//       CHECK:   [[LHS4:%.+]] = spirv.CompositeConstruct [[IN0]], [[PAD]] : (vector<3xi8>, i8) -> vector<4xi8>
//       CHECK:   [[RHS4:%.+]] = spirv.CompositeConstruct [[IN1]], [[PAD]] : (vector<3xi8>, i8) -> vector<4xi8>
//       CHECK:   [[RES:%.+]] = spirv.SDot [[LHS4]], [[RHS4]] : vector<4xi8> -> i32
//       CHECK:   return [[RES]]
func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
  %a = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
  %b = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
  %prod = arith.muli %a, %b : vector<3xi32>
  %sum = vector.reduction <add>, %prod : vector<3xi32> into i32
  return %sum : i32
}
140
// -----

// Negative tests: reductions that must be left as vector.reduction.
144
// 2-element vectors: the vector.reduction is expected to remain in the output
// and feed the return directly.
// CHECK-LABEL: func.func @too_short
//  CHECK-SAME:   ([[LHS:%.+]]: vector<2xi8>, [[RHS:%.+]]: vector<2xi8>)
//  CHECK:        [[KEPT:%.+]] = vector.reduction
//  CHECK-NEXT:   return [[KEPT]] : i32
func.func @too_short(%arg0: vector<2xi8>, %arg1: vector<2xi8>) -> i32 {
  %a = arith.extsi %arg0 : vector<2xi8> to vector<2xi32>
  %b = arith.extsi %arg1 : vector<2xi8> to vector<2xi32>
  %prod = arith.muli %a, %b : vector<2xi32>
  %sum = vector.reduction <add>, %prod : vector<2xi32> into i32
  return %sum : i32
}
156
// 6-element vectors: the vector.reduction is expected to remain in the output
// and feed the return directly.
// CHECK-LABEL: func.func @too_long
//  CHECK-SAME:   ([[LHS:%.+]]: vector<6xi8>, [[RHS:%.+]]: vector<6xi8>)
//  CHECK:        [[KEPT:%.+]] = vector.reduction
//  CHECK-NEXT:   return [[KEPT]] : i32
func.func @too_long(%arg0: vector<6xi8>, %arg1: vector<6xi8>) -> i32 {
  %a = arith.extsi %arg0 : vector<6xi8> to vector<6xi32>
  %b = arith.extsi %arg1 : vector<6xi8> to vector<6xi32>
  %prod = arith.muli %a, %b : vector<6xi32>
  %sum = vector.reduction <add>, %prod : vector<6xi32> into i32
  return %sum : i32
}
168
// Operand shapes and the extsi/muli producer chain match the positive tests,
// but the reduction kind is <mul> instead of <add>; the reduction is expected
// to remain in the output.
// CHECK-LABEL: func.func @wrong_reduction_kind
//  CHECK-SAME:   ([[LHS:%.+]]: vector<4xi8>, [[RHS:%.+]]: vector<4xi8>)
//  CHECK:        [[KEPT:%.+]] = vector.reduction <mul>
//  CHECK-NEXT:   return [[KEPT]] : i32
func.func @wrong_reduction_kind(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a, %b : vector<4xi32>
  %sum = vector.reduction <mul>, %prod : vector<4xi32> into i32
  return %sum : i32
}
180
// The reduction operand is produced by arith.addi rather than arith.muli, so
// the reduction is expected to remain in the output. The reduction kind is
// deliberately <add>, the same kind used by every positive test, so that the
// addi producer is the only property separating this case from a rewritable
// one. With <mul> the case would overlap with @wrong_reduction_kind and would
// not exercise the producer condition on its own.
// CHECK-LABEL: func.func @wrong_arith_op
//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
//  CHECK:        [[ADD:%.+]] = arith.addi
//  CHECK:        [[RED:%.+]] = vector.reduction <add>, [[ADD]]
//  CHECK-NEXT:   return [[RED]] : i32
func.func @wrong_arith_op(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %add = arith.addi %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %add : vector<4xi32> into i32
  return %red : i32
}
193