# RUN: %PYTHON %s 2>&1 | FileCheck %s

import ctypes
import sys
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg
from mlir.passmanager import *
from mlir.execution_engine import *

from mlir.dialects.linalg.opdsl.lang import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print `args` to stderr and flush immediately."""
    print(*args, file=sys.stderr)
    sys.stderr.flush()


# Each `*_boiler` string below is the textual MLIR of a `@main` driver. It is
# appended to the functions built by a test (see `transform`) and calls them
# by symbol name, so the names of the Python-defined functions must match the
# `call @...` targets in the corresponding string.

elemwise_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %lhs = memref.alloc() : memref<f32>
  %rhs = memref.alloc() : memref<4x8xf32>
  %O0 = memref.alloc() : memref<4x8xf32>
  %O1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)

  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""

fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""

fill_rng_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

# NOTE(review): no test visible in this file defines @conv_on_buffers or
# passes `conv_boiler` to `transform`; confirm whether it is still needed.
conv_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""

pooling_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""


def transform(module, boilerplate):
    """Combine `module` with `boilerplate` and lower the result to LLVM.

    The top-level operations of `module` are printed, concatenated with the
    `boilerplate` MLIR text, and re-parsed into a fresh module. The lowering
    pipeline below is then run on that new module in place.

    Args:
        module: Module holding the functions built by a test.
        boilerplate: MLIR text defining the `@main` driver that calls them.

    Returns:
        The newly parsed module after the pass pipeline has run.
    """
    # TODO: Allow cloning functions from one module to another.
    # Atm we have to resort to string concatenation.
    ops = module.operation.regions[0].blocks[0].operations
    mod = Module.parse("\n".join(str(op) for op in ops) + boilerplate)

    # Passes are added one by one, in exactly this order.
    lowering_pipeline = (
        "func.func(convert-linalg-to-loops)",
        "func.func(lower-affine)",
        "func.func(convert-math-to-llvm)",
        "func.func(convert-scf-to-cf)",
        "func.func(arith-expand)",
        "func.func(memref-expand)",
        "convert-vector-to-llvm",
        "finalize-memref-to-llvm",
        "convert-func-to-llvm",
        "convert-arith-to-llvm",
        "convert-cf-to-llvm",
        "reconcile-unrealized-casts",
    )
    pm = PassManager("builtin.module")
    for pipeline in lowering_pipeline:
        pm.add(pipeline)
    pm.run(mod.operation)
    return mod


def _invoke_main_and_log(module, boilerplate, result_ctype):
    """Lower `module` plus `boilerplate`, invoke `main`, and log its result.

    Args:
        module: Module holding the functions built by a test.
        boilerplate: MLIR text defining the `@main` driver.
        result_ctype: ctypes scalar type matching the return type of `main`
            (e.g. `ctypes.c_float` or `ctypes.c_int`).
    """
    execution_engine = ExecutionEngine(transform(module, boilerplate))

    # TODO: FFI-based solution to allow testing and printing with python code.
    # Prepare arguments: one result slot, pre-set to -1.
    # Arguments must be passed as pointers, hence the one-element array.
    res = (result_ctype * 1)(-1)
    execution_engine.invoke("main", res)

    log("RESULT: ", res[0])


def test_elemwise_builtin():
    """Run the elementwise unary/binary ops emitted as named ops."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out])
                linalg.elemwise_binary(out, rhs, outs=[out])

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
                linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)

        _invoke_main_and_log(module, elemwise_boiler, ctypes.c_float)
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_builtin()


def test_elemwise_generic():
    """Run the elementwise unary/binary ops emitted with emit_generic=True."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
                linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(
                    lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
                )
                linalg.elemwise_binary(
                    out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
                )

        _invoke_main_and_log(module, elemwise_boiler, ctypes.c_float)
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_generic()


def test_fill_builtin():
    """Run linalg.fill on 0-d, 1-d and 2-d buffers emitted as named ops."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

        _invoke_main_and_log(module, fill_boiler, ctypes.c_int)
        # CHECK: RESULT: 6


test_fill_builtin()


def test_fill_generic():
    """Run linalg.fill on 0-d, 1-d and 2-d buffers with emit_generic=True."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

        _invoke_main_and_log(module, fill_boiler, ctypes.c_int)
        # CHECK: RESULT: 6


test_fill_generic()


def test_fill_rng_builtin():
    """Run linalg.fill_rng_2d emitted as a named op."""
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out])

        _invoke_main_and_log(module, fill_rng_boiler, ctypes.c_int)
        # CHECK: RESULT: -480


test_fill_rng_builtin()


def test_fill_rng_generic():
    """Run linalg.fill_rng_2d with emit_generic=True."""
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

        _invoke_main_and_log(module, fill_rng_boiler, ctypes.c_int)
        # CHECK: RESULT: -480


test_fill_rng_generic()


def test_max_pooling_builtin():
    """Run linalg.pooling_nhwc_max emitted as a named op."""
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]
                )

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_builtin()


def test_max_pooling_generic():
    """Run linalg.pooling_nhwc_max with emit_generic=True."""
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input,
                    shape,
                    outs=[output],
                    strides=[2, 4],
                    dilations=[1, 2],
                    emit_generic=True,
                )

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_generic()


def test_min_pooling_builtin():
    """Run linalg.pooling_nhwc_min emitted as a named op."""
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                # Set the strides and use the default dilations.
                linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
        # CHECK: RESULT: -13


test_min_pooling_builtin()


def test_min_pooling_generic():
    """Run linalg.pooling_nhwc_min with emit_generic=True."""
    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64),
                MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32),
            )
            def pooling_on_buffers(input, shape, output):
                # Set the strides and use the default dilations.
                linalg.pooling_nhwc_min(
                    input, shape, outs=[output], strides=[2, 4], emit_generic=True
                )

        _invoke_main_and_log(module, pooling_boiler, ctypes.c_int)
        # CHECK: RESULT: -13


test_min_pooling_generic()