# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import quant


def run(f):
    """Print a ``TEST:`` header line for ``f``, call it once, and return it.

    Used as a decorator, so every decorated test executes when this script
    runs; returning ``f`` keeps the decorated name bound to the function.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
    """Check ``isinstance`` dispatch across the quantized type hierarchy."""
    with Context():
        i8 = IntegerType.get_signless(8)
        # Named ``any_type`` (not ``any``) to avoid shadowing the builtin.
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # A plain integer type is not quantized; every parsed quant type is.
        assert not quant.QuantizedType.isinstance(i8)
        assert quant.QuantizedType.isinstance(any_type)
        assert quant.QuantizedType.isinstance(uniform)
        assert quant.QuantizedType.isinstance(per_axis)
        assert quant.QuantizedType.isinstance(calibrated)

        # Each concrete class recognizes its own kind of type...
        assert quant.AnyQuantizedType.isinstance(any_type)
        assert quant.UniformQuantizedType.isinstance(uniform)
        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
        assert quant.CalibratedQuantizedType.isinstance(calibrated)

        # ...and rejects sibling kinds.
        assert not quant.AnyQuantizedType.isinstance(uniform)
        assert not quant.UniformQuantizedType.isinstance(per_axis)


# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
    """Build an ``AnyQuantizedType`` and check its accessors and printed form."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Named ``any_type`` (not ``any``) to avoid shadowing the builtin.
        any_type = quant.AnyQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
        )

        # CHECK: flags: 1
        print(f"flags: {any_type.flags}")
        # CHECK: signed: True
        print(f"signed: {any_type.is_signed}")
        # CHECK: storage type: i8
        print(f"storage type: {any_type.storage_type}")
        # CHECK: expressed type: f32
        print(f"expressed type: {any_type.expressed_type}")
        # CHECK: storage min: -8
        print(f"storage min: {any_type.storage_type_min}")
        # CHECK: storage max: 7
        print(f"storage max: {any_type.storage_type_max}")
        # CHECK: storage width: 8
        print(f"storage width: {any_type.storage_type_integral_width}")
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        print(f"quantized element type: {any_type.quantized_element_type}")
        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)
        # The builder result must equal the type parsed from its textual form.
        assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")


# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
    """Build a ``UniformQuantizedType`` and check scale, zero point and printing."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Note: the flag is looked up on the subclass here, whereas the other
        # tests use quant.QuantizedType.FLAG_SIGNED.
        uniform = quant.UniformQuantizedType.get(
            quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
        )

        # CHECK: scale: 0.99872
        print(f"scale: {uniform.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform)
        assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")


# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
    """Build a per-axis uniform type and check its accessors and printing.

    The storage bounds come from the ``default_minimum_for_integer`` /
    ``default_maximum_for_integer`` helpers for a signed 8-bit integer.
    """
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        per_axis = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            i8,
            f32,
            [200, 0.99872],
            [0, 120],
            quantized_dimension=1,
            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
                is_signed=True, integral_width=8
            ),
            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
                is_signed=True, integral_width=8
            ),
        )

        # CHECK: scales: [200.0, 0.99872]
        print(f"scales: {per_axis.scales}")
        # CHECK: zero_points: [0, 120]
        print(f"zero_points: {per_axis.zero_points}")
        # CHECK: quantized dim: 1
        print(f"quantized dim: {per_axis.quantized_dimension}")
        # CHECK: fixed point: False
        print(f"fixed point: {per_axis.is_fixed_point}")
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis)
        assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")


# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
    """Build a ``CalibratedQuantizedType`` over f32 and check min/max and printing."""
    with Context():
        f32 = F32Type.get()
        calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)

        # CHECK: min: -0.998
        print(f"min: {calibrated.min}")
        # CHECK: max: 1.2321
        print(f"max: {calibrated.max}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated)
        assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")