// RUN: mlir-opt %s \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" --canonicalize \
// RUN:   -convert-scf-to-cf --convert-complex-to-standard \
// RUN:   -finalize-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
// RUN:   -convert-vector-to-llvm -convert-complex-to-llvm \
// RUN:   -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm \
// RUN:   -reconcile-unrealized-casts |\
// RUN: mlir-runner \
// RUN:  -e entry -entry-point-result=void  \
// RUN:  -shared-libs=%mlir_c_runner_utils |\
// RUN: FileCheck %s

// Applies %func to every element of %input and prints the real part and then
// the imaginary part of each result, one value per line.
func.func @test_unary(%input: tensor<?xcomplex<f32>>,
                      %func: (complex<f32>) -> complex<f32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f32>>

  scf.for %i = %c0 to %size step %c1 {
    %elem = tensor.extract %input[%i] : tensor<?xcomplex<f32>>

    %val = func.call_indirect %func(%elem) : (complex<f32>) -> complex<f32>
    %real = complex.re %val : complex<f32>
    %imag = complex.im %val : complex<f32>
    vector.print %real : f32
    vector.print %imag : f32
    scf.yield
  }
  func.return
}

// Single-op wrappers so each complex op can be referenced with func.constant
// and passed to @test_unary as a function value.

func.func @sqrt(%arg: complex<f32>) -> complex<f32> {
  %sqrt = complex.sqrt %arg : complex<f32>
  func.return %sqrt : complex<f32>
}

func.func @tanh(%arg: complex<f32>) -> complex<f32> {
  %tanh = complex.tanh %arg : complex<f32>
  func.return %tanh : complex<f32>
}

// Reciprocal square root; the result is named after the op like the other
// wrappers (it was previously misleadingly named %sqrt).
func.func @rsqrt(%arg: complex<f32>) -> complex<f32> {
  %rsqrt = complex.rsqrt %arg : complex<f32>
  func.return %rsqrt : complex<f32>
}

func.func @conj(%arg: complex<f32>) -> complex<f32> {
  %conj = complex.conj %arg : complex<f32>
  func.return %conj : complex<f32>
}

// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
// For each consecutive (lhs, rhs) pair of %input, calls %func(lhs, rhs) and
// prints the real part and then the imaginary part of the result, one value
// per line.
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f32>>
  // Round the loop bound down to an even count so a trailing unpaired element
  // (odd-sized %input) is skipped instead of reading %input[%size] through
  // %i_next. For even sizes this leaves the bound unchanged.
  %rem = arith.remui %size, %c2 : index
  %pairs_end = arith.subi %size, %rem : index

  scf.for %i = %c0 to %pairs_end step %c2 {
    %lhs = tensor.extract %input[%i] : tensor<?xcomplex<f32>>
    %i_next = arith.addi %i, %c1 : index
    %rhs = tensor.extract %input[%i_next] : tensor<?xcomplex<f32>>

    %val = func.call_indirect %func(%lhs, %rhs)
      : (complex<f32>, complex<f32>) -> complex<f32>
    %real = complex.re %val : complex<f32>
    %imag = complex.im %val : complex<f32>
    vector.print %real : f32
    vector.print %imag : f32
    scf.yield
  }
  func.return
}

// Binary single-op wrappers, passed to @test_binary via func.constant.

func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
  func.return %atan2 : complex<f32>
}

func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %pow = complex.pow %lhs, %rhs : complex<f32>
  func.return %pow : complex<f32>
}

// Applies %func (complex<f32> -> f32) to every element of %input and prints
// each scalar result on its own line.
func.func @test_element(%input: tensor<?xcomplex<f32>>,
                        %func: (complex<f32>) -> f32) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f32>>

  scf.for %i = %c0 to %size step %c1 {
    %elem = tensor.extract %input[%i] : tensor<?xcomplex<f32>>

    %val = func.call_indirect %func(%elem) : (complex<f32>) -> f32
    vector.print %val : f32
    scf.yield
  }
  func.return
}

func.func @angle(%arg: complex<f32>) -> f32 {
  %angle = complex.angle %arg : complex<f32>
  func.return %angle : f32
}

// Same as @test_element, but for complex<f64> inputs and f64 results.
func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
                            %func: (complex<f64>) -> f64) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f64>>

  scf.for %i = %c0 to %size step %c1 {
    %elem = tensor.extract %input[%i] : tensor<?xcomplex<f64>>

    %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
    vector.print %val : f64
    scf.yield
  }
  func.return
}

func.func @abs(%arg: complex<f64>) -> f64 {
  %abs = complex.abs %arg : complex<f64>
  func.return %abs : f64
}

// Test driver. Each section builds a constant table of inputs, casts it to a
// dynamically shaped tensor (the drivers take tensor<?x...>), and runs one
// complex op over it. The FileCheck directives following each input give the
// expected printed output(s) for that input.
func.func @entry() {
  // complex.sqrt test
  %sqrt_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: 0.455
    // CHECK-NEXT: -1.098
    (-1.0, 1.0),
    // CHECK-NEXT: 0.455
    // CHECK-NEXT: 1.098
    (0.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 0.707
    // CHECK-NEXT: 0.707
    (1.0, -1.0),
    // CHECK-NEXT: 1.098
    // CHECK-NEXT: -0.455
    (1.0, 0.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 1.098
    // CHECK-NEXT: 0.455
  ]> : tensor<7xcomplex<f32>>
  %sqrt_test_cast = tensor.cast %sqrt_test
    : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
  call @test_unary(%sqrt_test_cast, %sqrt_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.atan2 test
  %atan2_test = arith.constant dense<[
    (1.0, 2.0), (2.0, 1.0),
    // CHECK: 0.785
    // CHECK-NEXT: 0.346
    (1.0, 1.0), (1.0, 0.0),
    // CHECK-NEXT: 1.017
    // CHECK-NEXT: 0.402
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT: 0.785
    // CHECK-NEXT: 0
  ]> : tensor<6xcomplex<f32>>
  %atan2_test_cast = tensor.cast %atan2_test
    : tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>

  %atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
    -> complex<f32>
  call @test_binary(%atan2_test_cast, %atan2_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
    -> complex<f32>) -> ()

  // complex.pow test
  %pow_test = arith.constant dense<[
    (0.0, 0.0), (0.0, 0.0),
    // CHECK: 1
    // CHECK-NEXT: 0
    (0.0, 0.0), (1.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 0.0), (-1.0, 0.0),
    // The sign of NaN is not checked because it cannot be tested in a
    // platform-agnostic manner. See: #58531
    // CHECK-NEXT: nan
    // CHECK-NEXT: nan
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT: 0.273
    // CHECK-NEXT: 0.583
  ]> : tensor<8xcomplex<f32>>
  %pow_test_cast = tensor.cast %pow_test
    : tensor<8xcomplex<f32>> to tensor<?xcomplex<f32>>

  %pow_func = func.constant @pow : (complex<f32>, complex<f32>)
    -> complex<f32>
  call @test_binary(%pow_test_cast, %pow_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
    -> complex<f32>) -> ()

  // complex.tanh test
  %tanh_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: -1.08392
    // CHECK-NEXT: -0.271753
    (-1.0, 1.0),
    // CHECK-NEXT: -1.08392
    // CHECK-NEXT: 0.271753
    (0.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 1.5574
    (1.0, -1.0),
    // CHECK-NEXT: 1.08392
    // CHECK-NEXT: -0.271753
    (1.0, 0.0),
    // CHECK-NEXT: 0.761594
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 1.08392
    // CHECK-NEXT: 0.271753
  ]> : tensor<7xcomplex<f32>>
  %tanh_test_cast = tensor.cast %tanh_test
    : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %tanh_func = func.constant @tanh : (complex<f32>) -> complex<f32>
  call @test_unary(%tanh_test_cast, %tanh_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.rsqrt test
  %rsqrt_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: 0.321
    // CHECK-NEXT: 0.776
    (-1.0, 1.0),
    // CHECK-NEXT: 0.321
    // CHECK-NEXT: -0.776
    (0.0, 0.0),
    // CHECK-NEXT: inf
    // CHECK-NEXT: nan
    (0.0, 1.0),
    // CHECK-NEXT: 0.707
    // CHECK-NEXT: -0.707
    (1.0, -1.0),
    // CHECK-NEXT: 0.776
    // CHECK-NEXT: 0.321
    (1.0, 0.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 0.776
    // CHECK-NEXT: -0.321
  ]> : tensor<7xcomplex<f32>>
  %rsqrt_test_cast = tensor.cast %rsqrt_test
    : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %rsqrt_func = func.constant @rsqrt : (complex<f32>) -> complex<f32>
  call @test_unary(%rsqrt_test_cast, %rsqrt_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.conj test
  %conj_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: -1
    // CHECK-NEXT: 1
    (-1.0, 1.0),
    // CHECK-NEXT: -1
    // CHECK-NEXT: -1
    (0.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: -1
    (1.0, -1.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 1
    (1.0, 0.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 1
    // CHECK-NEXT: -1
  ]> : tensor<7xcomplex<f32>>
  %conj_test_cast = tensor.cast %conj_test
    : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
  call @test_unary(%conj_test_cast, %conj_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.angle test
  %angle_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: -2.356
    (-1.0, 1.0),
    // CHECK-NEXT: 2.356
    (0.0, 0.0),
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 1.570
    (1.0, -1.0),
    // CHECK-NEXT: -0.785
    (1.0, 0.0),
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 0.785
  ]> : tensor<7xcomplex<f32>>
  %angle_test_cast = tensor.cast %angle_test
    : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %angle_func = func.constant @angle : (complex<f32>) -> f32
  call @test_element(%angle_test_cast, %angle_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()

  // complex.abs test. Uses f64 so that the 1e300 / 1e-300 inputs are
  // representable.
  %abs_test = arith.constant dense<[
    (1.0, 1.0),
    // CHECK: 1.414
    (1.0e300, 1.0e300),
    // CHECK-NEXT: 1.41421e+300
    (1.0e-300, 1.0e-300),
    // CHECK-NEXT: 1.41421e-300
    (5.0, 0.0),
    // CHECK-NEXT: 5
    (0.0, 6.0),
    // CHECK-NEXT: 6
    (7.0, 8.0),
    // CHECK-NEXT: 10.6301
    // For inputs with negative components the match is anchored at the start
    // of the line ({{^}}), so a wrongly signed result such as "-1" cannot pass
    // by matching only its "1" substring.
    (-1.0, -1.0),
    // CHECK-NEXT: {{^}}1.414
    (-1.0e300, -1.0e300),
    // CHECK-NEXT: {{^}}1.41421e+300
    (-1.0, 0.0),
    // CHECK-NEXT: {{^}}1{{$}}
    (0.0, -1.0)
    // CHECK-NEXT: {{^}}1{{$}}
  ]> : tensor<10xcomplex<f64>>
  %abs_test_cast = tensor.cast %abs_test
    : tensor<10xcomplex<f64>> to tensor<?xcomplex<f64>>

  %abs_func = func.constant @abs : (complex<f64>) -> f64

  call @test_element_f64(%abs_test_cast, %abs_func)
    : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()

  func.return
}