# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import quant


def run(f):
    """Print a FileCheck ``TEST:`` label for ``f``, call it, and return it.

    Applied as a decorator, so each test function is executed as soon as it is
    defined. Its output is preceded by the ``TEST: <name>`` line that the
    ``CHECK-LABEL`` directives match against.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
    """Check the ``isinstance`` predicates across the quant type hierarchy."""
    with Context():
        i8 = IntegerType.get_signless(8)
        # Named ``any_type`` rather than ``any`` to avoid shadowing the builtin.
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # Every quant type is a QuantizedType; a plain integer type is not.
        assert not quant.QuantizedType.isinstance(i8)
        assert quant.QuantizedType.isinstance(any_type)
        assert quant.QuantizedType.isinstance(uniform)
        assert quant.QuantizedType.isinstance(per_axis)
        assert quant.QuantizedType.isinstance(calibrated)

        # Each concrete subclass recognizes its own instances.
        assert quant.AnyQuantizedType.isinstance(any_type)
        assert quant.UniformQuantizedType.isinstance(uniform)
        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
        assert quant.CalibratedQuantizedType.isinstance(calibrated)

        # ...and rejects instances of sibling subclasses.
        assert not quant.AnyQuantizedType.isinstance(uniform)
        assert not quant.UniformQuantizedType.isinstance(per_axis)
# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
    """Build an AnyQuantizedType and check its accessors and printed form."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Named ``any_type`` rather than ``any`` to avoid shadowing the builtin.
        any_type = quant.AnyQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
        )

        # CHECK: flags: 1
        print(f"flags: {any_type.flags}")
        # CHECK: signed: True
        print(f"signed: {any_type.is_signed}")
        # CHECK: storage type: i8
        print(f"storage type: {any_type.storage_type}")
        # CHECK: expressed type: f32
        print(f"expressed type: {any_type.expressed_type}")
        # CHECK: storage min: -8
        print(f"storage min: {any_type.storage_type_min}")
        # CHECK: storage max: 7
        print(f"storage max: {any_type.storage_type_max}")
        # CHECK: storage width: 8
        print(f"storage width: {any_type.storage_type_integral_width}")
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        print(f"quantized element type: {any_type.quantized_element_type}")
        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)
        # Building via the API must agree with parsing the textual form.
        assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")


# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
    """Build a UniformQuantizedType and check scale, zero point and printing."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        uniform = quant.UniformQuantizedType.get(
            quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
        )

        # CHECK: scale: 0.99872
        print(f"scale: {uniform.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform)
        assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")


# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
    """Build a UniformQuantizedPerAxisType and check its per-axis accessors."""
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        per_axis = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            i8,
            f32,
            [200, 0.99872],
            [0, 120],
            quantized_dimension=1,
            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
                is_signed=True, integral_width=8
            ),
            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
                is_signed=True, integral_width=8
            ),
        )

        # CHECK: scales: [200.0, 0.99872]
        print(f"scales: {per_axis.scales}")
        # CHECK: zero_points: [0, 120]
        print(f"zero_points: {per_axis.zero_points}")
        # CHECK: quantized dim: 1
        print(f"quantized dim: {per_axis.quantized_dimension}")
        # CHECK: fixed point: False
        print(f"fixed point: {per_axis.is_fixed_point}")
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis)
        assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")


# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
    """Build a CalibratedQuantizedType and check its min/max and printing."""
    with Context():
        f32 = F32Type.get()
        calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)

        # CHECK: min: -0.998
        print(f"min: {calibrated.min}")
        # CHECK: max: 1.2321
        print(f"max: {calibrated.max}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated)
        assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")