xref: /llvm-project/mlir/test/python/dialects/quant.py (revision 47ef5c4b7f85bc1c8a859d721db9fd1dde7b8d8e)
166d4090dSAlex Zinenko# RUN: %PYTHON %s | FileCheck %s
266d4090dSAlex Zinenko
366d4090dSAlex Zinenkofrom mlir.ir import *
466d4090dSAlex Zinenkofrom mlir.dialects import quant
566d4090dSAlex Zinenko
666d4090dSAlex Zinenko
def run(f):
    """Emit a ``TEST: <name>`` marker for *f*, execute it, and hand it back.

    Applied as a decorator, so every test below runs as soon as this module
    is executed. The printed marker is what the ``CHECK-LABEL`` lines match.
    """
    label = f.__name__
    print("\nTEST:", label)
    f()
    return f
1166d4090dSAlex Zinenko
1266d4090dSAlex Zinenko
# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
    """Exercise ``isinstance`` across the quant type classes.

    A plain signless i8 must not be a QuantizedType, while each parsed quant
    type must be one and must also match its own concrete class. Two spot
    checks confirm that AnyQuantizedType rejects a uniform type and that
    UniformQuantizedType rejects a per-axis type.
    """
    with Context():
        signless_i8 = IntegerType.get_signless(8)
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # Base class: only the quant types qualify.
        assert not quant.QuantizedType.isinstance(signless_i8)
        for quantized in (any_type, uniform, per_axis, calibrated):
            assert quant.QuantizedType.isinstance(quantized)

        # Each concrete class recognizes the type it models.
        for concrete_class, candidate in (
            (quant.AnyQuantizedType, any_type),
            (quant.UniformQuantizedType, uniform),
            (quant.UniformQuantizedPerAxisType, per_axis),
            (quant.CalibratedQuantizedType, calibrated),
        ):
            assert concrete_class.isinstance(candidate)

        # Concrete classes must not accept a sibling's instance.
        assert not quant.AnyQuantizedType.isinstance(uniform)
        assert not quant.UniformQuantizedType.isinstance(per_axis)
3666d4090dSAlex Zinenko
3766d4090dSAlex Zinenko
# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
    """Build an AnyQuantizedType programmatically and print its accessors.

    Storage is i8 flagged as signed with range [-8, 7], expressed as f32. The
    built type must compare equal to the parsed ``!quant.any`` form.
    """
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()
        any_type = quant.AnyQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, storage, expressed, -8, 7
        )

        # Each accessor is read and printed in turn, in the order below.
        # CHECK: flags: 1
        # CHECK: signed: True
        # CHECK: storage type: i8
        # CHECK: expressed type: f32
        # CHECK: storage min: -8
        # CHECK: storage max: 7
        # CHECK: storage width: 8
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        for label, accessor in (
            ("flags", "flags"),
            ("signed", "is_signed"),
            ("storage type", "storage_type"),
            ("expressed type", "expressed_type"),
            ("storage min", "storage_type_min"),
            ("storage max", "storage_type_max"),
            ("storage width", "storage_type_integral_width"),
            ("quantized element type", "quantized_element_type"),
        ):
            print(f"{label}: {getattr(any_type, accessor)}")

        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)
        assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")
6766d4090dSAlex Zinenko
6866d4090dSAlex Zinenko
# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
    """Build a UniformQuantizedType and print its scale and zero-point data.

    Storage is i8 flagged as signed with range [-8, 7], expressed as f32, with
    scale 0.99872 and zero point 127. The built type must compare equal to the
    parsed ``!quant.uniform`` form.
    """
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Spell the flag through QuantizedType, as the other builders in this
        # test file do.
        # NOTE(review): zero point 127 lies outside the [-8, 7] storage range;
        # presumably the constructor does not validate this. Confirm that it
        # is intentional.
        uniform = quant.UniformQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
        )

        # CHECK: scale: 0.99872
        print(f"scale: {uniform.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform)
        assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
8866d4090dSAlex Zinenko
8966d4090dSAlex Zinenko
# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
    """Build a per-axis UniformQuantizedType and print its per-axis data.

    Two scales and two zero points are quantized along dimension 1. The
    storage bounds come from the signed 8-bit defaults, and the built type
    must compare equal to the parsed form.
    """
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()
        # Same keyword arguments for both default-bound helpers.
        signed_i8 = dict(is_signed=True, integral_width=8)
        per_axis = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            storage,
            expressed,
            [200, 0.99872],
            [0, 120],
            quantized_dimension=1,
            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
                **signed_i8
            ),
            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
                **signed_i8
            ),
        )

        # CHECK: scales: [200.0, 0.99872]
        print("scales:", per_axis.scales)
        # CHECK: zero_points: [0, 120]
        print("zero_points:", per_axis.zero_points)
        # CHECK: quantized dim: 1
        print("quantized dim:", per_axis.quantized_dimension)
        # CHECK: fixed point: False
        print("fixed point:", per_axis.is_fixed_point)
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis)
        assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
12266d4090dSAlex Zinenko
12366d4090dSAlex Zinenko
# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
    """Build a CalibratedQuantizedType over f32 and print its bounds.

    The range is [-0.998, 1.2321], and the built type must compare equal to
    the parsed ``!quant.calibrated`` form.
    """
    with Context():
        calibrated = quant.CalibratedQuantizedType.get(
            F32Type.get(), -0.998, 1.2321
        )

        # CHECK: min: -0.998
        # CHECK: max: 1.2321
        for bound in ("min", "max"):
            print(f"{bound}: {getattr(calibrated, bound)}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated)
        assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
138