// xref: /llvm-project/mlir/test/Conversion/ConvertToSPIRV/arith.mlir (revision 25ae1a266d50f24a8fffc57152d7f3c3fcb65517)
// RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

//===----------------------------------------------------------------------===//
// integer arithmetic ops
//===----------------------------------------------------------------------===//

// Scalar i32 arithmetic: each arith op below is matched against its SPIR-V
// integer counterpart, with the signed/unsigned variant (SDiv vs. UDiv, UMod)
// chosen by the arith op.
// CHECK-LABEL: @int32_scalar
func.func @int32_scalar(%lhs: i32, %rhs: i32) {
  // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
  %0 = arith.addi %lhs, %rhs: i32
  // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
  %1 = arith.subi %lhs, %rhs: i32
  // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
  %2 = arith.muli %lhs, %rhs: i32
  // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
  %3 = arith.divsi %lhs, %rhs: i32
  // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
  %4 = arith.divui %lhs, %rhs: i32
  // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
  %5 = arith.remui %lhs, %rhs: i32
  return
}

// arith.remsi is checked as an expansion rather than a single op. The absolute
// values of both operands are taken with GL.SAbs, and their remainder is
// computed with UMod. A Select then picks either that remainder or its
// negation, depending on whether the dividend compares equal to its own
// absolute value.
// NOTE(review): SAbs of the most negative i32 presumably wraps to itself. If
// so, the equality test selects the non-negated branch for that negative
// dividend; confirm whether that input is meant to be covered.
// CHECK-LABEL: @int32_scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
  // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
  // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
  // CHECK:  %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
  // CHECK:  %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
  // CHECK:  %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
  // CHECK:      %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
  %0 = arith.remsi %lhs, %rhs: i32
  return
}

// -----

//===----------------------------------------------------------------------===//
// arith bitwise, logical and shift ops
//===----------------------------------------------------------------------===//

// On i32 operands, arith.andi/ori/xori are expected to lower to the SPIR-V
// bitwise ops.
// CHECK-LABEL: @bitwise_scalar
func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
  // CHECK: spirv.BitwiseAnd
  %0 = arith.andi %arg0, %arg1 : i32
  // CHECK: spirv.BitwiseOr
  %1 = arith.ori %arg0, %arg1 : i32
  // CHECK: spirv.BitwiseXor
  %2 = arith.xori %arg0, %arg1 : i32
  return
}

// Same bitwise mapping as the scalar case, applied to vector<4xi32> operands.
// CHECK-LABEL: @bitwise_vector
func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
  // CHECK: spirv.BitwiseAnd
  %0 = arith.andi %arg0, %arg1 : vector<4xi32>
  // CHECK: spirv.BitwiseOr
  %1 = arith.ori %arg0, %arg1 : vector<4xi32>
  // CHECK: spirv.BitwiseXor
  %2 = arith.xori %arg0, %arg1 : vector<4xi32>
  return
}

// On i1 operands the same arith ops are expected to lower to SPIR-V logical
// ops instead of bitwise ones; xori on booleans becomes LogicalNotEqual.
// CHECK-LABEL: @logical_scalar
func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
  // CHECK: spirv.LogicalAnd
  %0 = arith.andi %arg0, %arg1 : i1
  // CHECK: spirv.LogicalOr
  %1 = arith.ori %arg0, %arg1 : i1
  // CHECK: spirv.LogicalNotEqual
  %2 = arith.xori %arg0, %arg1 : i1
  return
}

// Same logical mapping as the scalar i1 case, applied to vector<4xi1> operands.
// CHECK-LABEL: @logical_vector
func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
  // CHECK: spirv.LogicalAnd
  %0 = arith.andi %arg0, %arg1 : vector<4xi1>
  // CHECK: spirv.LogicalOr
  %1 = arith.ori %arg0, %arg1 : vector<4xi1>
  // CHECK: spirv.LogicalNotEqual
  %2 = arith.xori %arg0, %arg1 : vector<4xi1>
  return
}

// Scalar i32 shifts: shli -> ShiftLeftLogical, shrsi -> ShiftRightArithmetic,
// shrui -> ShiftRightLogical.
// CHECK-LABEL: @shift_scalar
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
  // CHECK: spirv.ShiftLeftLogical
  %0 = arith.shli %arg0, %arg1 : i32
  // CHECK: spirv.ShiftRightArithmetic
  %1 = arith.shrsi %arg0, %arg1 : i32
  // CHECK: spirv.ShiftRightLogical
  %2 = arith.shrui %arg0, %arg1 : i32
  return
}

// Same shift mapping as the scalar case, applied to vector<4xi32> operands.
// CHECK-LABEL: @shift_vector
func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
  // CHECK: spirv.ShiftLeftLogical
  %0 = arith.shli %arg0, %arg1 : vector<4xi32>
  // CHECK: spirv.ShiftRightArithmetic
  %1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
  // CHECK: spirv.ShiftRightLogical
  %2 = arith.shrui %arg0, %arg1 : vector<4xi32>
  return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpf
//===----------------------------------------------------------------------===//

// Ordered float equality (oeq) on f32 is expected to lower to FOrdEqual.
// CHECK-LABEL: @cmpf
func.func @cmpf(%arg0 : f32, %arg1 : f32) {
  // CHECK: spirv.FOrdEqual
  %0 = arith.cmpf oeq, %arg0, %arg1 : f32
  return
}

// Single-element vector float comparisons. The ordered predicate (ogt) maps to
// the FOrd* family and the unordered predicate (ult) maps to the FUnord* family.
// CHECK-LABEL: @vec1cmpf
func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
  // CHECK: spirv.FOrdGreaterThan
  %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
  // CHECK: spirv.FUnordLessThan
  %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
  return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpi
//===----------------------------------------------------------------------===//

// Integer equality on i32 is expected to lower to spirv.IEqual.
// CHECK-LABEL: @cmpi
func.func @cmpi(%arg0 : i32, %arg1 : i32) {
  // CHECK: spirv.IEqual
  %0 = arith.cmpi eq, %arg0, %arg1 : i32
  return
}

// index-typed operands are expected to lower to the same spirv.IEqual as i32.
// CHECK-LABEL: @indexcmpi
func.func @indexcmpi(%arg0 : index, %arg1 : index) {
  // CHECK: spirv.IEqual
  %0 = arith.cmpi eq, %arg0, %arg1 : index
  return
}

// Single-element vector integer comparisons. The predicate picks the
// signed/unsigned SPIR-V op: ult -> ULessThan, sgt -> SGreaterThan.
// CHECK-LABEL: @vec1cmpi
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
  // CHECK: spirv.ULessThan
  %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
  // CHECK: spirv.SGreaterThan
  %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
  return
}

// Equality predicates on i1 map directly onto the SPIR-V logical equality ops.
// CHECK-LABEL: @boolcmpi_equality
func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
  // CHECK: spirv.LogicalEqual
  %0 = arith.cmpi eq, %arg0, %arg1 : i1
  // CHECK: spirv.LogicalNotEqual
  %1 = arith.cmpi ne, %arg0, %arg1 : i1
  return
}

// Unsigned ordering predicates on i1 do not map to a logical op. Each
// comparison is expected to be preceded by two spirv.Select ops (presumably
// one per operand, turning the i1 into an integer value) and then use the
// unsigned integer comparison.
// CHECK-LABEL: @boolcmpi_unsigned
func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
  // CHECK-COUNT-2: spirv.Select
  // CHECK: spirv.UGreaterThanEqual
  %0 = arith.cmpi uge, %arg0, %arg1 : i1
  // CHECK-COUNT-2: spirv.Select
  // CHECK: spirv.ULessThan
  %1 = arith.cmpi ult, %arg0, %arg1 : i1
  return
}

// Same logical equality mapping as the scalar i1 case, on vector<1xi1> operands.
// CHECK-LABEL: @vec1boolcmpi_equality
func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
  // CHECK: spirv.LogicalEqual
  %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
  // CHECK: spirv.LogicalNotEqual
  %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
  return
}

// Same Select-then-unsigned-compare pattern as the scalar i1 case, on
// vector<1xi1> operands.
// CHECK-LABEL: @vec1boolcmpi_unsigned
func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
  // CHECK-COUNT-2: spirv.Select
  // CHECK: spirv.UGreaterThanEqual
  %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
  // CHECK-COUNT-2: spirv.Select
  // CHECK: spirv.ULessThan
  %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
  return
}

// Same logical equality mapping as the scalar i1 case, on vector<4xi1> operands.
// CHECK-LABEL: @vecboolcmpi_equality
func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
  // CHECK: spirv.LogicalEqual
  %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
  // CHECK: spirv.LogicalNotEqual
  %1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
  return
}

// Same Select-then-unsigned-compare pattern as the scalar i1 case, on
// vector<3xi1> operands.
// CHECK-LABEL: @vecboolcmpi_unsigned
func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
  // CHECK-COUNT-2: spirv.Select
  // CHECK: spirv.UGreaterThanEqual
  %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
  // CHECK-COUNT-2: spirv.Select
  // CHECK: spirv.ULessThan
  %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
  return
}
