# RUN: %PYTHON %s | FileCheck %s
from functools import partialmethod

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
from array import array


def run(f):
    """Print a ``TEST:`` banner for ``f`` and execute it immediately.

    Used as a decorator. It returns ``None``, so the decorated name is not
    callable afterwards: each test body runs exactly once, at import time,
    and its printed output follows its banner line.
    """
    print("\nTEST:", f.__name__)
    f()


# CHECK-LABEL: TEST: testConstantOps
@run
def testConstantOps():
    """Build a scalar f32 ``arith.constant`` through the ``ConstantOp`` builder."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            arith.ConstantOp(value=42.42, result=F32Type.get())
        # CHECK: %cst = arith.constant 4.242000e+01 : f32
        print(module)


# CHECK-LABEL: TEST: testFastMathFlags
@run
def testFastMathFlags():
    """Pass a combined ``FastMathFlags`` bitmask to ``AddFOp`` and print it."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            a = arith.ConstantOp(value=42.42, result=F32Type.get())
            # Flags are OR-ed together into a single attribute value.
            r = arith.AddFOp(
                a, a, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            )
            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
            print(r)


# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
    """Overload ``+``, ``-`` and ``*`` on arith values via a value caster.

    A ``Value`` subclass is registered as the caster for f16/f32/f64 and
    integer types, so results of ops producing those types come back as
    ``ArithValue`` and support Python arithmetic operators.
    """

    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
        """Emit ``arith.<Op>F`` or ``arith.<Op>I`` for ``lhs`` and ``rhs``.

        ``op`` is the lowercase op stem (``"add"``, ``"sub"``, ``"mul"``).
        The ``F`` variant is chosen when both operand types are float types,
        the ``I`` variant when both are integer-like.

        Raises:
            NotImplementedError: if the operand types match neither case,
                e.g. one float operand and one integer operand.
        """
        op = op.capitalize()
        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
            op += "F"
        # Both operands must be integer-like. The right-hand side was
        # previously never inspected here.
        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
            rhs.type
        ):
            op += "I"
        else:
            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")

        op = getattr(arith, f"{op}Op")
        return op(lhs, rhs).result

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        def __init__(self, v):
            super().__init__(v)

        # partialmethod binds the instance as ``lhs`` and the operator's
        # other operand as ``rhs``.
        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

        def __str__(self):
            # Show the subclass name instead of the base "Value" prefix.
            return super().__str__().replace(Value.__name__, ArithValue.__name__)

    with Context(), Location.unknown():
        module = Module.create()
        f16_t = F16Type.get()
        f32_t = F32Type.get()
        f64_t = F64Type.get()

        with InsertionPoint(module.body):
            a = arith.constant(f16_t, 42.42)
            # CHECK: ArithValue(%cst = arith.constant 4.240
            print(a)

            b = a + a
            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
            print(b)

            a = arith.constant(f32_t, 42.42)
            b = a - a
            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
            print(b)

            a = arith.constant(f64_t, 42.42)
            b = a * a
            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
            print(b)


# CHECK-LABEL: TEST: testArrayConstantConstruction
@run
def testArrayConstantConstruction():
    """Build vector constants from ``array.array`` buffers in two ways.

    Each vector is built twice: once with ``arith.constant`` directly from
    the buffer, and once with ``ConstantOp`` from an explicitly constructed
    dense elements attribute. The two results must print identically.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            i32_array = array("i", [1, 2, 3, 4])
            i32 = IntegerType.get_signless(32)
            vec_i32 = VectorType.get([2, 2], i32)
            arith.constant(vec_i32, i32_array)
            arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))

            # "q" is the equivalent of `long long` in C and requires at least
            # 64 bit width integers on both Linux and Windows.
            i64_array = array("q", [5, 6, 7, 8])
            i64 = IntegerType.get_signless(64)
            vec_i64 = VectorType.get([1, 4], i64)
            arith.constant(vec_i64, i64_array)
            arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))

            f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
            f32 = F32Type.get()
            vec_f32 = VectorType.get([4, 1], f32)
            arith.constant(vec_f32, f32_array)
            arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))

            f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
            f64 = F64Type.get()
            vec_f64 = VectorType.get([2, 1, 2], f64)
            arith.constant(vec_f64, f64_array)
            arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))

        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
        print(module)