// Source: mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
// (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
// RUN: mlir-opt %s \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" --canonicalize \
// RUN:   -convert-scf-to-cf --convert-complex-to-standard \
// RUN:   -finalize-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
// RUN:   -convert-vector-to-llvm -convert-complex-to-llvm \
// RUN:   -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm \
// RUN:   -reconcile-unrealized-casts |\
// RUN: mlir-runner \
// RUN:  -e entry -entry-point-result=void  \
// RUN:  -shared-libs=%mlir_c_runner_utils |\
// RUN: FileCheck %s

// Calls %func on every element of %input, in index order, and prints the
// real part followed by the imaginary part of each result, one value per line.
func.func @test_unary(%input: tensor<?xcomplex<f32>>,
                      %func: (complex<f32>) -> complex<f32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %n = tensor.dim %input, %zero : tensor<?xcomplex<f32>>

  // The loop carries no values, so the scf.yield terminator is implicit.
  scf.for %idx = %zero to %n step %one {
    %z = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
    %res = func.call_indirect %func(%z) : (complex<f32>) -> complex<f32>
    %res_re = complex.re %res : complex<f32>
    %res_im = complex.im %res : complex<f32>
    vector.print %res_re : f32
    vector.print %res_im : f32
  }
  func.return
}

// Function-value wrapper around complex.sqrt; @entry passes it to @test_unary.
func.func @sqrt(%z: complex<f32>) -> complex<f32> {
  %root = complex.sqrt %z : complex<f32>
  func.return %root : complex<f32>
}

// Function-value wrapper around complex.tanh; @entry passes it to @test_unary.
func.func @tanh(%z: complex<f32>) -> complex<f32> {
  %result = complex.tanh %z : complex<f32>
  func.return %result : complex<f32>
}

// Function-value wrapper around complex.rsqrt; @entry passes it to @test_unary.
func.func @rsqrt(%z: complex<f32>) -> complex<f32> {
  %rsqrt = complex.rsqrt %z : complex<f32>
  func.return %rsqrt : complex<f32>
}

// Function-value wrapper around complex.conj; @entry passes it to @test_unary.
func.func @conj(%z: complex<f32>) -> complex<f32> {
  %conjugate = complex.conj %z : complex<f32>
  func.return %conjugate : complex<f32>
}

// Applies %func to consecutive (lhs, rhs) pairs of %input, laid out as
// [lhs_0, rhs_0, lhs_1, rhs_1, ...], and prints the real part followed by the
// imaginary part of each result, one value per line.
// The loop bound is rounded down to an even count. For an odd-length %input,
// the trailing unpaired element is therefore skipped instead of triggering an
// out-of-bounds read of %input[%size].
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f32>>
  // %size - (%size mod 2): the largest even bound <= %size, so that
  // %i + 1 is always a valid index inside the loop.
  %unpaired = arith.remui %size, %c2 : index
  %pairs_end = arith.subi %size, %unpaired : index

  scf.for %i = %c0 to %pairs_end step %c2 {
    %lhs = tensor.extract %input[%i] : tensor<?xcomplex<f32>>
    %i_next = arith.addi %i, %c1 : index
    %rhs = tensor.extract %input[%i_next] : tensor<?xcomplex<f32>>

    %val = func.call_indirect %func(%lhs, %rhs)
      : (complex<f32>, complex<f32>) -> complex<f32>
    %real = complex.re %val : complex<f32>
    %imag = complex.im %val : complex<f32>
    vector.print %real : f32
    vector.print %imag : f32
    scf.yield
  }
  func.return
}

// Function-value wrapper around complex.atan2, applied to (%lhs, %rhs) in that
// order; @entry passes it to @test_binary.
func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %angle = complex.atan2 %lhs, %rhs : complex<f32>
  func.return %angle : complex<f32>
}

// Function-value wrapper around complex.pow with %lhs as the first operand and
// %rhs as the second; @entry passes it to @test_binary.
func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %power = complex.pow %lhs, %rhs : complex<f32>
  func.return %power : complex<f32>
}

// Calls %func on every element of %input, in index order, and prints each
// f32 result on its own line.
func.func @test_element(%input: tensor<?xcomplex<f32>>,
                        %func: (complex<f32>) -> f32) {
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  %ub = tensor.dim %input, %lb : tensor<?xcomplex<f32>>

  // The loop carries no values, so the scf.yield terminator is implicit.
  scf.for %k = %lb to %ub step %step {
    %z = tensor.extract %input[%k] : tensor<?xcomplex<f32>>
    %r = func.call_indirect %func(%z) : (complex<f32>) -> f32
    vector.print %r : f32
  }
  func.return
}

// Function-value wrapper around complex.angle (complex<f32> -> f32);
// @entry passes it to @test_element.
func.func @angle(%z: complex<f32>) -> f32 {
  %phase = complex.angle %z : complex<f32>
  func.return %phase : f32
}

// f64 counterpart of @test_element: calls %func on every element of %input,
// in index order, and prints each f64 result on its own line.
func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
                            %func: (complex<f64>) -> f64) {
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  %ub = tensor.dim %input, %lb : tensor<?xcomplex<f64>>

  // The loop carries no values, so the scf.yield terminator is implicit.
  scf.for %k = %lb to %ub step %step {
    %z = tensor.extract %input[%k] : tensor<?xcomplex<f64>>
    %r = func.call_indirect %func(%z) : (complex<f64>) -> f64
    vector.print %r : f64
  }
  func.return
}

// Function-value wrapper around complex.abs (complex<f64> -> f64);
// @entry passes it to @test_element_f64.
func.func @abs(%z: complex<f64>) -> f64 {
  %magnitude = complex.abs %z : complex<f64>
  func.return %magnitude : f64
}

// Test driver: runs each complex-op test in turn. Each constant tensor below
// interleaves its inputs with the FileCheck expectations for the values the
// corresponding helper prints, in print order.
func.func @entry() {
  // complex.sqrt test
  %sqrt_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:       0.455
    // CHECK-NEXT: -1.098
    (-1.0, 1.0),
    // CHECK-NEXT:  0.455
    // CHECK-NEXT:  1.098
    (0.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  0.707
    // CHECK-NEXT:  0.707
    (1.0, -1.0),
    // CHECK-NEXT:  1.098
    // CHECK-NEXT:  -0.455
    (1.0, 0.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  1.098
    // CHECK-NEXT:  0.455
  ]> : tensor<7xcomplex<f32>>
  %sqrt_test_cast = tensor.cast %sqrt_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
  call @test_unary(%sqrt_test_cast, %sqrt_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.atan2 test
  %atan2_test = arith.constant dense<[
    (1.0, 2.0), (2.0, 1.0),
    // CHECK:       0.785
    // CHECK-NEXT:  0.346
    (1.0, 1.0), (1.0, 0.0),
    // CHECK-NEXT:  1.017
    // CHECK-NEXT:  0.402
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT:  0.785
    // CHECK-NEXT:  0
  ]> : tensor<6xcomplex<f32>>
  %atan2_test_cast = tensor.cast %atan2_test
    :  tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>

  %atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
    -> complex<f32>
  call @test_binary(%atan2_test_cast, %atan2_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
    -> complex<f32>) -> ()

  // complex.pow test
  %pow_test = arith.constant dense<[
    (0.0, 0.0), (0.0, 0.0),
    // CHECK:       1
    // CHECK-NEXT:  0
    (0.0, 0.0), (1.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 0.0), (-1.0, 0.0),
    // Ignoring the sign of nan as that can't be tested in platform agnostic manner. See: #58531
    // CHECK-NEXT:  nan
    // CHECK-NEXT:  nan
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT:  0.273
    // CHECK-NEXT:  0.583
  ]> : tensor<8xcomplex<f32>>
  %pow_test_cast = tensor.cast %pow_test
    :  tensor<8xcomplex<f32>> to tensor<?xcomplex<f32>>

  %pow_func = func.constant @pow : (complex<f32>, complex<f32>)
    -> complex<f32>
  call @test_binary(%pow_test_cast, %pow_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
    -> complex<f32>) -> ()

  // complex.tanh test
  %tanh_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:      -1.08392
    // CHECK-NEXT: -0.271753
    (-1.0, 1.0),
    // CHECK-NEXT:  -1.08392
    // CHECK-NEXT:  0.271753
    (0.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  1.5574
    (1.0, -1.0),
    // CHECK-NEXT:  1.08392
    // CHECK-NEXT:  -0.271753
    (1.0, 0.0),
    // CHECK-NEXT:  0.761594
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  1.08392
    // CHECK-NEXT:  0.271753
  ]> : tensor<7xcomplex<f32>>
  %tanh_test_cast = tensor.cast %tanh_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %tanh_func = func.constant @tanh : (complex<f32>) -> complex<f32>
  call @test_unary(%tanh_test_cast, %tanh_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.rsqrt test
  %rsqrt_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:       0.321
    // CHECK-NEXT:  0.776
    (-1.0, 1.0),
    // CHECK-NEXT:  0.321
    // CHECK-NEXT:  -0.776
    (0.0, 0.0),
    // CHECK-NEXT:  inf
    // CHECK-NEXT:  nan
    (0.0, 1.0),
    // CHECK-NEXT:  0.707
    // CHECK-NEXT:  -0.707
    (1.0, -1.0),
    // CHECK-NEXT:  0.776
    // CHECK-NEXT:  0.321
    (1.0, 0.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  0.776
    // CHECK-NEXT:  -0.321
  ]> : tensor<7xcomplex<f32>>
  %rsqrt_test_cast = tensor.cast %rsqrt_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %rsqrt_func = func.constant @rsqrt : (complex<f32>) -> complex<f32>
  call @test_unary(%rsqrt_test_cast, %rsqrt_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.conj test
  %conj_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:      -1
    // CHECK-NEXT: 1
    (-1.0, 1.0),
    // CHECK-NEXT:  -1
    // CHECK-NEXT:  -1
    (0.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  -1
    (1.0, -1.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  1
    (1.0, 0.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  1
    // CHECK-NEXT:  -1
  ]> : tensor<7xcomplex<f32>>
  %conj_test_cast = tensor.cast %conj_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
  call @test_unary(%conj_test_cast, %conj_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.angle test
  %angle_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:      -2.356
    (-1.0, 1.0),
    // CHECK-NEXT:  2.356
    (0.0, 0.0),
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  1.570
    (1.0, -1.0),
    // CHECK-NEXT:  -0.785
    (1.0, 0.0),
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  0.785
  ]> : tensor<7xcomplex<f32>>
  %angle_test_cast = tensor.cast %angle_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %angle_func = func.constant @angle : (complex<f32>) -> f32
  call @test_element(%angle_test_cast, %angle_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()

  // complex.abs test
  %abs_test = arith.constant dense<[
    (1.0, 1.0),
    // CHECK:  1.414
    (1.0e300, 1.0e300),
    // CHECK-NEXT:  1.41421e+300
    (1.0e-300, 1.0e-300),
    // CHECK-NEXT:  1.41421e-300
    (5.0, 0.0),
    // CHECK-NEXT:  5
    (0.0, 6.0),
    // CHECK-NEXT:  6
    (7.0, 8.0),
    // CHECK-NEXT:  10.6301
    (-1.0, -1.0),
    // CHECK-NEXT: 1.414
    (-1.0e300, -1.0e300),
    // CHECK-NEXT:  1.41421e+300
    // For the next two inputs the printed result must be exactly "1", never
    // "-1". The pattern is anchored to the whole line because an unanchored
    // "1" would also match inside a printed "-1".
    (-1.0, 0.0),
    // CHECK-NEXT: {{^}}1{{$}}
    (0.0, -1.0)
    // CHECK-NEXT: {{^}}1{{$}}
  ]> : tensor<10xcomplex<f64>>
  %abs_test_cast = tensor.cast %abs_test
    :  tensor<10xcomplex<f64>> to tensor<?xcomplex<f64>>

  %abs_func = func.constant @abs : (complex<f64>) -> f64

  call @test_element_f64(%abs_test_cast, %abs_func)
    : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()

  func.return
}
