// Source: llvm-project/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir (revision 40e2bb5330840b56d452244f96e491b6530ce4bf)
// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt --split-input-file --verify-diagnostics --canonicalize %s \
// RUN:  | FileCheck %s --check-prefix=CANON

//===----------------------------------------------------------------------===//
// spirv.BitCount
//===----------------------------------------------------------------------===//

// Round-trips the custom assembly form of spirv.BitCount on an i32 scalar.
// CHECK-LABEL: func @bitcount
func.func @bitcount(%arg: i32) -> i32 {
  // CHECK: spirv.BitCount {{%.*}} : i32
  %0 = spirv.BitCount %arg : i32
  spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.BitFieldInsert
//===----------------------------------------------------------------------===//

// Round-trips spirv.BitFieldInsert with vector base/insert operands and
// scalar offset/count operands of differing integer widths (i32, i16).
// CHECK-LABEL: func @bit_field_insert_vec
func.func @bit_field_insert_vec(%base: vector<3xi32>, %insert: vector<3xi32>, %offset: i32, %count: i16) -> vector<3xi32> {
  // CHECK: {{%.*}} = spirv.BitFieldInsert {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi32>, i32, i16
  %0 = spirv.BitFieldInsert %base, %insert, %offset, %count : vector<3xi32>, i32, i16
  spirv.ReturnValue %0 : vector<3xi32>
}

// -----

// A base/insert shape mismatch (vector<3xi32> vs. vector<2xi32>) must make
// the verifier report a type error on the generic-form op below.
func.func @bit_field_insert_invalid_insert_type(%base: vector<3xi32>, %insert: vector<2xi32>, %offset: i32, %count: i16) -> vector<3xi32> {
  // TODO: tighten once the verification order changes. Right now only the
  // presence of a type-verification failure is asserted, not its wording;
  // the final diagnostic should call out the base/insert mismatch.
  // expected-error @+1 {{type}}
  %bad = "spirv.BitFieldInsert" (%base, %insert, %offset, %count) : (vector<3xi32>, vector<2xi32>, i32, i16) -> vector<3xi32>
  spirv.ReturnValue %bad : vector<3xi32>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.BitFieldSExtract
//===----------------------------------------------------------------------===//

// Round-trips spirv.BitFieldSExtract with a vector base and i8 offset/count.
// CHECK-LABEL: func @bit_field_s_extract_vec
func.func @bit_field_s_extract_vec(%base: vector<3xi32>, %offset: i8, %count: i8) -> vector<3xi32> {
  // CHECK: {{%.*}} = spirv.BitFieldSExtract {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi32>, i8, i8
  %0 = spirv.BitFieldSExtract %base, %offset, %count : vector<3xi32>, i8, i8
  spirv.ReturnValue %0 : vector<3xi32>
}

//===----------------------------------------------------------------------===//
// spirv.BitFieldUExtract
//===----------------------------------------------------------------------===//

// Round-trips spirv.BitFieldUExtract with a vector base and i8 offset/count.
// CHECK-LABEL: func @bit_field_u_extract_vec
func.func @bit_field_u_extract_vec(%base: vector<3xi32>, %offset: i8, %count: i8) -> vector<3xi32> {
  // CHECK: {{%.*}} = spirv.BitFieldUExtract {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi32>, i8, i8
  %0 = spirv.BitFieldUExtract %base, %offset, %count : vector<3xi32>, i8, i8
  spirv.ReturnValue %0 : vector<3xi32>
}

// -----

// spirv.BitFieldUExtract must yield the same type as its base operand, so a
// vector<4xi32> result for a vector<3xi32> base fails verification.
func.func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> {
  // expected-error @+1 {{failed to verify that all of {base, result} have same type}}
  %extracted = "spirv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32>
  spirv.ReturnValue %extracted : vector<4xi32>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.BitReverse
//===----------------------------------------------------------------------===//

// Round-trips the custom assembly form of spirv.BitReverse on an i32 scalar.
// CHECK-LABEL: func @bitreverse
func.func @bitreverse(%arg: i32) -> i32 {
  // CHECK: spirv.BitReverse {{%.*}} : i32
  %0 = spirv.BitReverse %arg : i32
  spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.BitwiseOr
//===----------------------------------------------------------------------===//

// Round-trips spirv.BitwiseOr on an i32 scalar; the check pins both operands
// and the printed result type, not just the op name.
// CHECK-LABEL: func @bitwise_or_scalar
func.func @bitwise_or_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.BitwiseOr {{%.*}}, {{%.*}} : i32
  %0 = spirv.BitwiseOr %arg, %arg : i32
  return %0 : i32
}

// Round-trips spirv.BitwiseOr on a vector<4xi32>; the check pins both
// operands and the printed result type, not just the op name.
// CHECK-LABEL: func @bitwise_or_vector
func.func @bitwise_or_vector(%arg: vector<4xi32>) -> vector<4xi32> {
  // CHECK: spirv.BitwiseOr {{%.*}}, {{%.*}} : vector<4xi32>
  %0 = spirv.BitwiseOr %arg, %arg : vector<4xi32>
  return %0 : vector<4xi32>
}

// Canonicalization rewrites x | 0 to x: the function returns its argument.
// CANON-LABEL: func @bitwise_or_zero
// CANON-SAME:    (%[[ARG:.+]]: i32)
func.func @bitwise_or_zero(%arg: i32) -> i32 {
  // CANON: return %[[ARG]]
  %c0 = spirv.Constant 0 : i32
  %or = spirv.BitwiseOr %arg, %c0 : i32
  return %or : i32
}

// Vector form of x | 0 -> x: a splat-zero operand leaves the argument as-is.
// CANON-LABEL: func @bitwise_or_zero_vector
// CANON-SAME:    (%[[ARG:.+]]: vector<4xi32>)
func.func @bitwise_or_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> {
  // CANON: return %[[ARG]]
  %c0_splat = spirv.Constant dense<0> : vector<4xi32>
  %or = spirv.BitwiseOr %arg, %c0_splat : vector<4xi32>
  return %or : vector<4xi32>
}

// x | all-ones canonicalizes to the all-ones constant; 255 : i8 is printed
// back as -1.
// CANON-LABEL: func @bitwise_or_all_ones
func.func @bitwise_or_all_ones(%arg: i8) -> i8 {
  // CANON: %[[CST:.+]] = spirv.Constant -1
  // CANON: return %[[CST]]
  %all_set = spirv.Constant 255 : i8
  %or = spirv.BitwiseOr %arg, %all_set : i8
  return %or : i8
}

// Vector form of x | all-ones -> all-ones; the splat 255 : i8 is printed back
// as dense<-1>.
// CANON-LABEL: func @bitwise_or_all_ones_vector
func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
  // CANON: %[[CST:.+]] = spirv.Constant dense<-1>
  // CANON: return %[[CST]]
  %all_set = spirv.Constant dense<255> : vector<3xi8>
  %or = spirv.BitwiseOr %arg, %all_set : vector<3xi8>
  return %or : vector<3xi8>
}

// -----

// spirv.BitwiseOr only accepts integer (or integer-vector) operands; f16 is
// rejected.
func.func @bitwise_or_float(%lhs: f16, %rhs: f16) -> f16 {
  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %r = spirv.BitwiseOr %lhs, %rhs : f16
  return %r : f16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.BitwiseXor
//===----------------------------------------------------------------------===//

// Round-trips spirv.BitwiseXor on an i32 scalar.
// CHECK-LABEL: func @bitwise_xor_scalar
func.func @bitwise_xor_scalar(%arg: i32) -> i32 {
  // XOR with a constant rather than with %arg itself, to avoid folding
  // (x ^ x could fold to 0).
  %c1 = spirv.Constant 1 : i32
  // CHECK: spirv.BitwiseXor {{%.*}}, {{%.*}} : i32
  %0 = spirv.BitwiseXor %c1, %arg : i32
  return %0 : i32
}

// Round-trips spirv.BitwiseXor on a vector<4xi32>.
// CHECK-LABEL: func @bitwise_xor_vector
func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
  // XOR with a constant rather than with %arg itself, to avoid folding
  // (x ^ x could fold to 0).
  %c1 = spirv.Constant dense<1> : vector<4xi32>
  // CHECK: spirv.BitwiseXor {{%.*}}, {{%.*}} : vector<4xi32>
  %0 = spirv.BitwiseXor %c1, %arg : vector<4xi32>
  return %0 : vector<4xi32>
}

// -----

// spirv.BitwiseXor only accepts integer (or integer-vector) operands; f16 is
// rejected.
func.func @bitwise_xor_float(%lhs: f16, %rhs: f16) -> f16 {
  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %r = spirv.BitwiseXor %lhs, %rhs : f16
  return %r : f16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.BitwiseAnd
//===----------------------------------------------------------------------===//

// Round-trips spirv.BitwiseAnd on an i32 scalar; the check pins both operands
// and the printed result type, not just the op name.
// CHECK-LABEL: func @bitwise_and_scalar
func.func @bitwise_and_scalar(%arg: i32) -> i32 {
  // CHECK: spirv.BitwiseAnd {{%.*}}, {{%.*}} : i32
  %0 = spirv.BitwiseAnd %arg, %arg : i32
  return %0 : i32
}

// Round-trips spirv.BitwiseAnd on a vector<4xi32>; the check pins both
// operands and the printed result type, not just the op name.
// CHECK-LABEL: func @bitwise_and_vector
func.func @bitwise_and_vector(%arg: vector<4xi32>) -> vector<4xi32> {
  // CHECK: spirv.BitwiseAnd {{%.*}}, {{%.*}} : vector<4xi32>
  %0 = spirv.BitwiseAnd %arg, %arg : vector<4xi32>
  return %0 : vector<4xi32>
}

// x & 0 canonicalizes to the zero constant.
// CANON-LABEL: func @bitwise_and_zero
func.func @bitwise_and_zero(%arg: i32) -> i32 {
  // CANON: %[[CST:.+]] = spirv.Constant 0
  // CANON: return %[[CST]]
  %c0 = spirv.Constant 0 : i32
  %and = spirv.BitwiseAnd %arg, %c0 : i32
  return %and : i32
}

// Vector form of x & 0 -> 0: the result is the splat-zero constant.
// CANON-LABEL: func @bitwise_and_zero_vector
func.func @bitwise_and_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> {
  // CANON: %[[CST:.+]] = spirv.Constant dense<0>
  // CANON: return %[[CST]]
  %c0_splat = spirv.Constant dense<0> : vector<4xi32>
  %and = spirv.BitwiseAnd %arg, %c0_splat : vector<4xi32>
  return %and : vector<4xi32>
}

// x & all-ones canonicalizes to x: the function returns its argument.
// CANON-LABEL: func @bitwise_and_all_ones
// CANON-SAME:    (%[[ARG:.+]]: i8)
func.func @bitwise_and_all_ones(%arg: i8) -> i8 {
  // CANON: return %[[ARG]]
  %all_set = spirv.Constant 255 : i8
  %and = spirv.BitwiseAnd %arg, %all_set : i8
  return %and : i8
}

// Vector form of x & all-ones -> x: a splat 255 : i8 mask leaves the argument
// unchanged.
// CANON-LABEL: func @bitwise_and_all_ones_vector
// CANON-SAME:    (%[[ARG:.+]]: vector<3xi8>)
func.func @bitwise_and_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
  // CANON: return %[[ARG]]
  %all_set = spirv.Constant dense<255> : vector<3xi8>
  %and = spirv.BitwiseAnd %arg, %all_set : vector<3xi8>
  return %and : vector<3xi8>
}

// Masking an i8 widened by spirv.UConvert with 0xff keeps every low bit, so
// canonicalization drops the spirv.BitwiseAnd and returns the UConvert.
// CANON-LABEL: func @bitwise_and_zext_1
// CANON-SAME:    (%[[ARG:.+]]: i8)
func.func @bitwise_and_zext_1(%arg: i8) -> i32 {
  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
  // CANON: return %[[ZEXT]]
  %widened = spirv.UConvert %arg : i8 to i32
  %mask = spirv.Constant 255 : i32
  %masked = spirv.BitwiseAnd %widened, %mask : i32
  return %masked : i32
}

// Same as above but with extra high bits in the mask (0x12345ff): only the
// low byte matters for a widened i8, and it is all ones, so the AND is
// removed.
// CANON-LABEL: func @bitwise_and_zext_2
// CANON-SAME:    (%[[ARG:.+]]: i8)
func.func @bitwise_and_zext_2(%arg: i8) -> i32 {
  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
  // CANON: return %[[ZEXT]]
  %widened = spirv.UConvert %arg : i8 to i32
  %mask = spirv.Constant 0x12345ff : i32
  %masked = spirv.BitwiseAnd %widened, %mask : i32
  return %masked : i32
}

// Negative case: 254 (0xfe) clears bit 0 of the widened i8, so the
// spirv.BitwiseAnd must survive canonicalization.
// CANON-LABEL: func @bitwise_and_zext_3
// CANON-SAME:    (%[[ARG:.+]]: i8)
func.func @bitwise_and_zext_3(%arg: i8) -> i32 {
  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
  // CANON: %[[AND:.+]] = spirv.BitwiseAnd %[[ZEXT]]
  // CANON: return %[[AND]]
  %widened = spirv.UConvert %arg : i8 to i32
  %mask = spirv.Constant 254 : i32
  %masked = spirv.BitwiseAnd %widened, %mask : i32
  return %masked : i32
}

// Vector form of the widened-i8 & 0xff simplification: the AND is dropped and
// the spirv.UConvert result is returned directly.
// CANON-LABEL: func @bitwise_and_zext_vector
// CANON-SAME:    (%[[ARG:.+]]: vector<2xi8>)
func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
  // CANON: return %[[ZEXT]]
  %widened = spirv.UConvert %arg : vector<2xi8> to vector<2xi32>
  %mask = spirv.Constant dense<255> : vector<2xi32>
  %masked = spirv.BitwiseAnd %widened, %mask : vector<2xi32>
  return %masked : vector<2xi32>
}

// -----

// spirv.BitwiseAnd only accepts integer (or integer-vector) operands; f16 is
// rejected.
func.func @bitwise_and_float(%lhs: f16, %rhs: f16) -> f16 {
  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
  %r = spirv.BitwiseAnd %lhs, %rhs : f16
  return %r : f16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.Not
//===----------------------------------------------------------------------===//

// Round-trips the custom assembly form of spirv.Not on an i32 scalar.
// CHECK-LABEL: func @not
func.func @not(%arg: i32) -> i32 {
  // CHECK: spirv.Not {{%.*}} : i32
  %0 = spirv.Not %arg : i32
  spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//

// Round-trips spirv.ShiftLeftLogical where the shift amount (i16) is narrower
// than the shifted value (i32).
// CHECK-LABEL: func @shift_left_logical
func.func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
  // CHECK: {{%.*}} = spirv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16
  %0 = spirv.ShiftLeftLogical %arg0, %arg1: i32, i16
  spirv.ReturnValue %0 : i32
}

// -----

// spirv.ShiftLeftLogical must produce the type of its first operand; an i16
// result for an i32 operand fails verification.
func.func @shift_left_logical_invalid_result_type(%value: i32, %shift : i16) -> i16 {
  // expected-error @+1 {{op failed to verify that all of {operand1, result} have same type}}
  %shifted = "spirv.ShiftLeftLogical" (%value, %shift) : (i32, i16) -> (i16)
  spirv.ReturnValue %shifted : i16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ShiftRightArithmetic
//===----------------------------------------------------------------------===//

// Round-trips spirv.ShiftRightArithmetic on vectors whose shift-amount element
// type (i8) differs from the shifted element type (i32).
// CHECK-LABEL: func @shift_right_arithmetic
func.func @shift_right_arithmetic(%arg0: vector<4xi32>, %arg1 : vector<4xi8>) -> vector<4xi32> {
  // CHECK: {{%.*}} = spirv.ShiftRightArithmetic {{%.*}}, {{%.*}} : vector<4xi32>, vector<4xi8>
  %0 = spirv.ShiftRightArithmetic %arg0, %arg1: vector<4xi32>, vector<4xi8>
  spirv.ReturnValue %0 : vector<4xi32>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ShiftRightLogical
//===----------------------------------------------------------------------===//

// Round-trips spirv.ShiftRightLogical on vectors whose shift-amount element
// type (i8) differs from the shifted element type (i32).
// CHECK-LABEL: func @shift_right_logical
func.func @shift_right_logical(%arg0: vector<2xi32>, %arg1 : vector<2xi8>) -> vector<2xi32> {
  // CHECK: {{%.*}} = spirv.ShiftRightLogical {{%.*}}, {{%.*}} : vector<2xi32>, vector<2xi8>
  %0 = spirv.ShiftRightLogical %arg0, %arg1: vector<2xi32>, vector<2xi8>
  spirv.ReturnValue %0 : vector<2xi32>
}
