xref: /llvm-project/mlir/test/python/dialects/arith_dialect.py (revision 5d59fa90ce225814739d9b51ba37e1cca9204cad)
1# RUN: %PYTHON %s | FileCheck %s
2from functools import partialmethod
3
4from mlir.ir import *
5import mlir.dialects.arith as arith
6import mlir.dialects.func as func
7from array import array
8
9
def run(f):
    """Print a FileCheck label for ``f``, invoke it, and return it.

    Intended to be used as a decorator: each test function runs when the
    module is executed, and its output is preceded by a ``TEST: <name>``
    line that the ``CHECK-LABEL`` directives in this file anchor on.

    Returning ``f`` (instead of the implicit ``None``) keeps the decorated
    name bound to the test function so it can still be referenced.
    """
    print("\nTEST:", f.__name__)
    f()
    return f
13
14
# CHECK-LABEL: TEST: testConstantOps
@run
def testConstantOps():
    """Build a scalar f32 ``arith.constant`` through the ``ConstantOp`` builder."""
    # The label above spells out the full function name so it cannot
    # accidentally anchor on a different test sharing the prefix.
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            arith.ConstantOp(value=42.42, result=F32Type.get())
        # CHECK:         %cst = arith.constant 4.242000e+01 : f32
        print(module)
24
25
# CHECK-LABEL: TEST: testFastMathFlags
@run
def testFastMathFlags():
    """Check that fastmath flags given to ``arith.AddFOp`` appear in its asm."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            operand = arith.ConstantOp(value=42.42, result=F32Type.get())
            # Combine two flags to verify the printer lists each of them.
            flags = arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            add_op = arith.AddFOp(operand, operand, fastmath=flags)
            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
            print(add_op)
38
39
# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
    """Exercise a user-registered value caster that adds Python operators.

    ``ArithValue`` is registered as the value caster for the f16, f32, f64
    and integer type ids, so it wraps the results printed below, and it maps
    ``+``, ``-`` and ``*`` onto arith ops.
    """

    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
        """Emit the arith op for ``lhs <op> rhs`` and return its result value.

        The float variant (e.g. ``AddFOp``) is used when both operands are
        float-typed and the integer variant (e.g. ``AddIOp``) when both are
        integer-like. Any other combination raises ``NotImplementedError``.
        """
        op_name = op.capitalize()
        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
            op_name += "F"
        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
            rhs.type
        ):
            # Check rhs as well: otherwise an integer lhs paired with a
            # non-integer rhs would be routed to the integer op instead of
            # being rejected below.
            op_name += "I"
        else:
            raise NotImplementedError(
                f"Unsupported '{op_name}' operands: {lhs}, {rhs}"
            )

        op_cls = getattr(arith, f"{op_name}Op")
        return op_cls(lhs, rhs).result

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        """``Value`` subclass overloading ``+``, ``-`` and ``*`` with arith ops."""

        def __init__(self, v):
            super().__init__(v)

        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

        def __str__(self):
            # Reuse Value's printing but show the subclass name, so the CHECK
            # lines can confirm the caster was applied.
            return super().__str__().replace(Value.__name__, ArithValue.__name__)

    with Context(), Location.unknown():
        module = Module.create()
        f16_t = F16Type.get()
        f32_t = F32Type.get()
        f64_t = F64Type.get()

        with InsertionPoint(module.body):
            a = arith.constant(f16_t, 42.42)
            # CHECK: ArithValue(%cst = arith.constant 4.240
            print(a)

            b = a + a
            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
            print(b)

            a = arith.constant(f32_t, 42.42)
            b = a - a
            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
            print(b)

            a = arith.constant(f64_t, 42.42)
            b = a * a
            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
            print(b)
96
97
# CHECK-LABEL: TEST: testArrayConstantConstruction
@run
def testArrayConstantConstruction():
    """Build vector constants from ``array.array`` buffers in two ways.

    For each element type, one constant goes through the ``arith.constant``
    helper with the raw buffer. The other goes through ``arith.ConstantOp``
    with an explicitly built dense elements attribute. The CHECK-COUNT-2
    lines require both forms to print the same constant.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # Each entry: (buffer, element type, vector shape, attribute class).
            cases = [
                (
                    array("i", [1, 2, 3, 4]),
                    IntegerType.get_signless(32),
                    [2, 2],
                    DenseIntElementsAttr,
                ),
                # "q" is the equivalent of `long long` in C and requires at least
                # 64 bit width integers on both Linux and Windows.
                (
                    array("q", [5, 6, 7, 8]),
                    IntegerType.get_signless(64),
                    [1, 4],
                    DenseIntElementsAttr,
                ),
                (
                    array("f", [1.0, 2.0, 3.0, 4.0]),
                    F32Type.get(),
                    [4, 1],
                    DenseFPElementsAttr,
                ),
                (
                    array("d", [1.0, 2.0, 3.0, 4.0]),
                    F64Type.get(),
                    [2, 1, 2],
                    DenseFPElementsAttr,
                ),
            ]
            for values, element_type, shape, attr_cls in cases:
                vector_type = VectorType.get(shape, element_type)
                arith.constant(vector_type, values)
                arith.ConstantOp(vector_type, attr_cls.get(values, type=vector_type))

        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
        print(module)
135