xref: /llvm-project/mlir/test/python/dialects/arith_dialect.py (revision 5d59fa90ce225814739d9b51ba37e1cca9204cad)
1ec5def5eSMehdi Amini# RUN: %PYTHON %s | FileCheck %s
27c850867SMaksim Leventalfrom functools import partialmethod
3ec5def5eSMehdi Amini
4ec5def5eSMehdi Aminifrom mlir.ir import *
5ec5def5eSMehdi Aminiimport mlir.dialects.arith as arith
67c850867SMaksim Leventalimport mlir.dialects.func as func
7*5d59fa90SOleksandr "Alex" Zinenkofrom array import array
8ec5def5eSMehdi Amini
9f9008e63STobias Hieta
def run(f):
    """Run ``f`` immediately, preceded by a ``TEST: <name>`` banner.

    Intended as a decorator on the test functions below: the banner is the
    anchor that the per-test FileCheck label directives match against, the
    test body executes at import time, and because nothing is returned the
    decorated module-level name ends up bound to ``None``.
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
13ec5def5eSMehdi Amini
14f9008e63STobias Hieta
# CHECK-LABEL: TEST: testConstantOps
@run
def testConstantOps():
    """Build a scalar f32 ``arith.constant`` and print the enclosing module.

    The label directive above names the function exactly as ``run`` prints
    it in its ``TEST:`` banner, matching the other tests in this file.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            arith.ConstantOp(value=42.42, result=F32Type.get())
        # CHECK:         %cst = arith.constant 4.242000e+01 : f32
        print(module)
2492233062Smax
2592233062Smax
# CHECK-LABEL: TEST: testFastMathFlags
@run
def testFastMathFlags():
    """Create an ``arith.addf`` carrying the ``nnan`` and ``ninf`` fast-math flags.

    The op is printed on its own so the expected textual form of the
    ``fastmath<...>`` attribute can be matched directly.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            operand = arith.ConstantOp(value=42.42, result=F32Type.get())
            # Flags combine as a bitmask via ``|``.
            flags = arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            add_op = arith.AddFOp(operand, operand, fastmath=flags)
            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
            print(add_op)
38a2288a89SMaksim Levental
39a2288a89SMaksim Levental
# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
    """Exercise a user-registered value caster that adds Python arithmetic.

    ``ArithValue`` is registered as the value caster for f16, f32, f64 and
    integer typed values, so the results of ``arith.constant`` and of the
    overloaded ``+``/``-``/``*`` operators are handed back as ``ArithValue``.
    """

    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
        """Emit the float or integer arith op named by ``op`` for ``lhs, rhs``.

        ``op`` is a lowercase base name such as ``"add"``; the ``F``/``I``
        flavour is picked from the operand types. Both operands must be
        float, or both integer-like; any other combination raises
        NotImplementedError.
        """
        op_name = op.capitalize()
        # Check BOTH operands: a mixed integer/float pair has no single arith
        # op and must be rejected rather than routed to the integer flavour.
        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
            suffix = "F"
        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
            rhs.type
        ):
            suffix = "I"
        else:
            raise NotImplementedError(
                f"Unsupported '{op_name}' operands: {lhs}, {rhs}"
            )

        # e.g. op="add" with f32 operands resolves to arith.AddFOp.
        op_cls = getattr(arith, f"{op_name}{suffix}Op")
        return op_cls(lhs, rhs).result

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        """``Value`` subclass mapping ``+``, ``-`` and ``*`` onto arith ops."""

        def __init__(self, v):
            super().__init__(v)

        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

        def __str__(self):
            # Rename only the first occurrence, i.e. the leading "Value("
            # wrapper; an unbounded replace would also rewrite any "Value"
            # that happens to appear inside the printed IR text.
            return super().__str__().replace(Value.__name__, ArithValue.__name__, 1)

    with Context(), Location.unknown():
        module = Module.create()
        f16_t = F16Type.get()
        f32_t = F32Type.get()
        f64_t = F64Type.get()

        with InsertionPoint(module.body):
            a = arith.constant(f16_t, 42.42)
            # CHECK: ArithValue(%cst = arith.constant 4.240
            print(a)

            b = a + a
            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
            print(b)

            a = arith.constant(f32_t, 42.42)
            b = a - a
            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
            print(b)

            a = arith.constant(f64_t, 42.42)
            b = a * a
            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
            print(b)
96*5d59fa90SOleksandr "Alex" Zinenko
97*5d59fa90SOleksandr "Alex" Zinenko
# CHECK-LABEL: TEST: testArrayConstantConstruction
@run
def testArrayConstantConstruction():
    """Build vector constants from ``array.array`` buffers two different ways.

    For each element type, the same constant is created once through the
    ``arith.constant`` convenience builder and once through ``arith.ConstantOp``
    with an explicitly constructed dense elements attribute; both spellings
    are expected to print identically.
    """
    # Each case: (array typecode, values, element-type factory, vector shape,
    # dense attribute class). "q" is C `long long`, which Python's array
    # module guarantees to be at least 8 bytes wide on every platform, so it
    # always provides 64-bit integers for the i64 case.
    cases = (
        ("i", [1, 2, 3, 4], lambda: IntegerType.get_signless(32), [2, 2], DenseIntElementsAttr),
        ("q", [5, 6, 7, 8], lambda: IntegerType.get_signless(64), [1, 4], DenseIntElementsAttr),
        ("f", [1.0, 2.0, 3.0, 4.0], F32Type.get, [4, 1], DenseFPElementsAttr),
        ("d", [1.0, 2.0, 3.0, 4.0], F64Type.get, [2, 1, 2], DenseFPElementsAttr),
    )

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            for typecode, values, make_elem_type, shape, attr_cls in cases:
                buffer = array(typecode, values)
                vector_type = VectorType.get(shape, make_elem_type())
                arith.constant(vector_type, buffer)
                arith.ConstantOp(vector_type, attr_cls.get(buffer, type=vector_type))

        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
        print(module)
135