# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *

# This tests miscellaneous features of the emitter that are not tested by the
# fill, matmul, convolution, or pooling tests. The features include:
# - constants defined in the body
# - fixed/predefined types
# - index computations (`index(D.x)`) in the body
# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
#   (negf is exercised on both f32 and complex<f32> element types)
# - custom op names.


@linalg_structured_op
def test_const(O=TensorDef(F32, S.M, S.N, output=True)):
    """Writes the sum of two body-defined constants into every element of O.

    The integer constant 42 and the float constant 2.3283064e-10 are each
    converted to the fixed F32 type with `cast_unsigned` before being added.
    """
    O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
        F32, const(2.3283064e-10)
    )


@linalg_structured_op
def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
    """Writes the sum of the two iteration indices into every element of O.

    Each index is converted to the fixed I32 type with `cast_signed`.
    """
    O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
        I32, index(D.n)
    )


@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(T),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Applies the unary function `fun` elementwise to I, writing into O.

    The input element is first converted to the output element type U with
    `cast`. `fun` defaults to exp and `cast` to cast_signed; callers may
    override either. The operands carry no shape symbols and are indexed with
    `[None]`, presumably so the op is usable at any rank (callers in this file
    use 4x16 tensors).
    """
    O[None] = fun(cast(U, I[None]))


@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
    """Copies I into O; registered as "custom_op_name", not the function name."""
    O[D.n] = I[D.n]


with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    c32 = ComplexType.get(f32)
    i32 = IntegerType.get_signless(32)
    # Tensor types shared by the test functions below.
    f32_4x16 = RankedTensorType.get((4, 16), f32)
    i32_4x16 = RankedTensorType.get((4, 16), i32)
    c32_4x16 = RankedTensorType.get((4, 16), c32)
    f32_42 = RankedTensorType.get((42,), f32)
    with InsertionPoint(module.body):

        # CHECK-LABEL: @test_f32_const
        # CHECK-DAG: %[[CST0:.+]] = arith.constant 42 : i64
        # CHECK-DAG: %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
        # CHECK-DAG: %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
        # CHECK-DAG: %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
        # CHECK-DAG: %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
        # CHECK-NEXT: linalg.yield %[[SUM]] : f32
        @func.FuncOp.from_py_func(f32_4x16)
        def test_f32_const(init_result):
            return test_const(outs=[init_result])

        # CHECK-LABEL: @test_i32_index
        # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
        # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
        # CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
        # CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
        # CHECK-DAG: %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
        # CHECK-NEXT: linalg.yield %[[SUM]] : i32
        @func.FuncOp.from_py_func(i32_4x16)
        def test_i32_index(init_result):
            return test_index(outs=[init_result])

        # CHECK-LABEL: @test_f32_elemwise_exp
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[EXP]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(f32_4x16, f32_4x16)
        def test_f32_elemwise_exp(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.exp)

        # CHECK-LABEL: @test_f32_elemwise_log
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[LOG]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(f32_4x16, f32_4x16)
        def test_f32_elemwise_log(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.log)

        # CHECK-LABEL: @test_f32_elemwise_abs
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[ABS:.+]] = math.absf %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[ABS]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(f32_4x16, f32_4x16)
        def test_f32_elemwise_abs(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.abs)

        # CHECK-LABEL: @test_f32_elemwise_ceil
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(f32_4x16, f32_4x16)
        def test_f32_elemwise_ceil(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.ceil)

        # CHECK-LABEL: @test_f32_elemwise_floor
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(f32_4x16, f32_4x16)
        def test_f32_elemwise_floor(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.floor)

        # CHECK-LABEL: @test_f32_elemwise_neg
        # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
        # CHECK-NEXT: linalg.yield %[[NEG]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(f32_4x16, f32_4x16)
        def test_f32_elemwise_neg(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.negf)

        # UnaryFn.negf on a complex element type lowers to complex.neg.
        # CHECK-LABEL: @test_c32_elemwise_neg
        # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
        # CHECK-NEXT: %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
        # CHECK-NEXT: linalg.yield %[[NEG]] : complex<f32>
        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
        @func.FuncOp.from_py_func(c32_4x16, c32_4x16)
        def test_c32_elemwise_neg(in_tensor, init_result):
            return elemwise_unary_poly(in_tensor, outs=[init_result], fun=UnaryFn.negf)

        # Only checks that emitting an op whose op_name ("custom_op_name")
        # differs from its Python function name does not assert.
        # CHECK-LABEL: @test_non_default_op_name
        @func.FuncOp.from_py_func(f32_42, f32_42)
        def test_non_default_op_name(in_tensor, init_result):
            return non_default_op_name(in_tensor, outs=[init_result])


print(module)