# RUN: %PYTHON %s | FileCheck %s
"""Tests for the Python bindings of the MLIR ``arith`` dialect.

Each test builds IR inside a fresh context and prints it; FileCheck compares
that output against the directive comments placed next to each ``print``.
"""

from array import array
from functools import partialmethod

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func


def run(f):
    """Print a ``TEST: <name>`` banner for ``f`` and call it immediately.

    Used as a decorator, so every test executes when this script runs; the
    banner is what each test's ``-LABEL`` directive anchors on.
    """
    print("\nTEST:", f.__name__)
    f()


# CHECK-LABEL: TEST: testConstantOps
@run
def testConstantOps():
    """Build a scalar f32 ``arith.constant`` through the generated op class."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            arith.ConstantOp(value=42.42, result=F32Type.get())
        # CHECK: %cst = arith.constant 4.242000e+01 : f32
        print(module)


# CHECK-LABEL: TEST: testFastMathFlags
@run
def testFastMathFlags():
    """Combine ``FastMathFlags`` with ``|`` and attach them to ``arith.addf``."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            a = arith.ConstantOp(value=42.42, result=F32Type.get())
            r = arith.AddFOp(
                a, a, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            )
            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
            print(r)


# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
    """Register ``ArithValue`` as the value caster for float and integer types.

    Values of those types returned by the bindings print as ``ArithValue``
    below, and their ``+``, ``-`` and ``*`` operators build the matching
    ``arith`` op.
    """

    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
        """Build ``lhs <op> rhs`` with the arith op matching the operand types.

        ``op`` is a lowercase base name such as "add", "sub" or "mul". It is
        capitalized and suffixed with "F" (both operands float) or "I" (both
        operands integer-like). For example, "add" on two floats builds
        ``arith.AddFOp``.

        Returns:
            The single result of the created op.

        Raises:
            NotImplementedError: if the operand types are neither both float
                nor both integer-like.
        """
        op = op.capitalize()
        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
            op += "F"
        # Both operands must be integer-like; checking only lhs would send an
        # integer lhs paired with a non-integer rhs to the integer op.
        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
            rhs.type
        ):
            op += "I"
        else:
            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")

        op = getattr(arith, f"{op}Op")
        return op(lhs, rhs).result

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        """``Value`` subclass exposing Python arithmetic operators."""

        def __init__(self, v):
            super().__init__(v)

        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

        def __str__(self):
            # Rename the "Value(...)" prefix so the printed form names the
            # caster class that wrapped the value.
            return super().__str__().replace(Value.__name__, ArithValue.__name__)

    with Context(), Location.unknown():
        module = Module.create()
        f16_t = F16Type.get()
        f32_t = F32Type.get()
        f64_t = F64Type.get()

        with InsertionPoint(module.body):
            a = arith.constant(f16_t, 42.42)
            # CHECK: ArithValue(%cst = arith.constant 4.240e+01 : f16)
            print(a)

            b = a + a
            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
            print(b)

            a = arith.constant(f32_t, 42.42)
            b = a - a
            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
            print(b)

            a = arith.constant(f64_t, 42.42)
            b = a * a
            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
            print(b)


# CHECK-LABEL: TEST: testArrayConstantConstruction
@run
def testArrayConstantConstruction():
    """Build vector constants from Python ``array`` buffers.

    Each constant is built twice: once through the ``arith.constant`` helper
    from the raw buffer, and once through ``arith.ConstantOp`` with an
    explicit dense elements attribute. Both forms must print identically,
    hence the count of 2 in every directive below.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            i32_array = array("i", [1, 2, 3, 4])
            i32 = IntegerType.get_signless(32)
            vec_i32 = VectorType.get([2, 2], i32)
            arith.constant(vec_i32, i32_array)
            arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))

            # "q" is C `long long`, which is at least 64 bits on both Linux and
            # Windows ("l" is only 32 bits on Windows).
            i64_array = array("q", [5, 6, 7, 8])
            i64 = IntegerType.get_signless(64)
            vec_i64 = VectorType.get([1, 4], i64)
            arith.constant(vec_i64, i64_array)
            arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))

            f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
            f32 = F32Type.get()
            vec_f32 = VectorType.get([4, 1], f32)
            arith.constant(vec_f32, f32_array)
            arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))

            f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
            f64 = F64Type.get()
            vec_f64 = VectorType.get([2, 1, 2], f64)
            arith.constant(vec_f64, f64_array)
            arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))

        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
        print(module)