# xref: /llvm-project/mlir/test/python/dialects/linalg/opdsl/emit_misc.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *

# This tests miscellaneous features of the emitter that are not covered by the
# fill, matmul, convolution, or pooling tests. The features include:
# - constants defined in the body
# - fixed/predefined types
# - index access to the iteration domain (`index`)
# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
#   (negf is exercised on both f32 and complex<f32> operands)
# - custom op names.


@linalg_structured_op
def test_const(O=TensorDef(F32, S.M, S.N, output=True)):
    """Fill O with the sum of two constants defined inside the op body.

    Both constants, the integer 42 and a small float literal, are converted to
    F32 with `TypeFn.cast_unsigned` before being added, so the emitter has to
    materialize and convert one integer and one floating-point constant.
    """
    O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
        F32, const(2.3283064e-10)
    )


@linalg_structured_op
def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
    """Write the sum of the two iteration indices, each cast to I32, into O."""
    O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
        I32, index(D.n)
    )


@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(T),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Apply the unary function `fun` elementwise to I after casting it to U.

    `fun` (default: exp) and `cast` (default: signed cast) are op attributes,
    so each instantiation may select a different unary function or type
    conversion. The tensors are declared without a shape and accessed as
    `[None]`; presumably this makes the op usable at any rank (it is
    instantiated on 4x16 tensors below).
    """
    O[None] = fun(cast(U, I[None]))


@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
    """Copy the 1-D tensor I into O.

    The op is registered as `custom_op_name`, deliberately different from the
    Python function name, to exercise custom op naming.
    """
    O[D.n] = I[D.n]


# Build one module holding a function per feature under test, then print it so
# FileCheck can match the directives written above each function.
with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    c32 = ComplexType.get(f32)
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):
        # CHECK-LABEL: @test_f32_const
        # CHECK-DAG:    %[[CST0:.+]] = arith.constant 42 : i64
        # CHECK-DAG:    %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
        # CHECK-DAG:    %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
        # CHECK-DAG:    %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
        # CHECK-DAG:    %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
        # CHECK-NEXT:   linalg.yield %[[SUM]] : f32
        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
        def test_f32_const(init_result):
            return test_const(outs=[init_result])

        # CHECK-LABEL: @test_i32_index
        # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
        # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
        # CHECK-DAG:    %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
        # CHECK-DAG:    %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
        # CHECK-DAG:    %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
        # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
        def test_i32_index(init_result):
            return test_index(outs=[init_result])

        # CHECK-LABEL: @test_f32_elemwise_exp
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_exp(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)

        # CHECK-LABEL: @test_f32_elemwise_log
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
        # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_log(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)

        # CHECK-LABEL: @test_f32_elemwise_abs
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[ABS:.+]] = math.absf %[[IN]] : f32
        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_abs(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)

        # CHECK-LABEL: @test_f32_elemwise_ceil
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[CEIL:.+]] = math.ceil %[[IN]] : f32
        # CHECK-NEXT:   linalg.yield %[[CEIL]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_ceil(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)

        # CHECK-LABEL: @test_f32_elemwise_floor
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[FLOOR:.+]] = math.floor %[[IN]] : f32
        # CHECK-NEXT:   linalg.yield %[[FLOOR]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_floor(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)

        # CHECK-LABEL: @test_f32_elemwise_neg
        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
        # CHECK-NEXT:   %[[NEG:.+]] = arith.negf %[[IN]] : f32
        # CHECK-NEXT:   linalg.yield %[[NEG]] : f32
        # CHECK-NEXT: -> tensor<4x16xf32>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
        )
        def test_f32_elemwise_neg(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)

        # The same UnaryFn.negf lowers to complex.neg for complex operands.
        # CHECK-LABEL: @test_c32_elemwise_neg
        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
        # CHECK-NEXT:   %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
        # CHECK-NEXT:   linalg.yield %[[NEG]] : complex<f32>
        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
        )
        def test_c32_elemwise_neg(input, init_result):
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)

        # Just check that we don't assert out on name mismatch.
        # CHECK-LABEL: @test_non_default_op_name
        @func.FuncOp.from_py_func(
            RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)
        )
        def test_non_default_op_name(input, init_result):
            return non_default_op_name(input, outs=[init_result])


print(module)