xref: /llvm-project/mlir/test/python/dialects/quant.py (revision 47ef5c4b7f85bc1c8a859d721db9fd1dde7b8d8e)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import quant
5
6
def run(f):
    """Print a ``TEST: <name>`` banner for ``f``, call it, and return it.

    Applied as a decorator to the test functions in this file, so each
    decorated function is executed immediately when it is defined, and its
    name stays bound to the original function object.
    """
    test_name = f.__name__
    # The leading newline separates the banner from any earlier output.
    print("\nTEST:", test_name)
    f()
    return f
11
12
13# CHECK-LABEL: TEST: test_type_hierarchy
@run
def test_type_hierarchy():
    """Verify ``isinstance`` relationships between the quant type classes.

    Parses one type of each quant flavor (any, uniform, uniform per-axis,
    calibrated) and asserts that each is a ``QuantizedType`` and matches its
    own specific class, that a plain ``i8`` integer type is not a
    ``QuantizedType``, and that two mismatched class/type pairs are rejected.
    """
    with Context():
        i8 = IntegerType.get_signless(8)
        # Named ``any_type`` rather than ``any`` so the builtin is not shadowed.
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # Base-class membership: only the quant types qualify.
        assert not quant.QuantizedType.isinstance(i8)
        for quantized in (any_type, uniform, per_axis, calibrated):
            assert quant.QuantizedType.isinstance(quantized)

        # Each parsed type matches its own specific class.
        assert quant.AnyQuantizedType.isinstance(any_type)
        assert quant.UniformQuantizedType.isinstance(uniform)
        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
        assert quant.CalibratedQuantizedType.isinstance(calibrated)

        # Mismatched class/type pairs are rejected.
        assert not quant.AnyQuantizedType.isinstance(uniform)
        assert not quant.UniformQuantizedType.isinstance(per_axis)
36
37
38# CHECK-LABEL: TEST: test_any_quantized_type
@run
def test_any_quantized_type():
    """Build an ``AnyQuantizedType`` and print its properties for FileCheck.

    The type is constructed with the signed flag, ``i8`` storage, ``f32``
    expressed type and a storage range of -8 to 7. Each property is printed
    right after the directive that matches it, and the result is finally
    compared against the equivalent parsed type.
    """
    with Context():
        i8 = IntegerType.get_signless(8)
        f32 = F32Type.get()
        # Named ``any_type`` rather than ``any`` so the builtin is not shadowed.
        any_type = quant.AnyQuantizedType.get(
            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
        )

        # CHECK: flags: 1
        print(f"flags: {any_type.flags}")
        # CHECK: signed: True
        print(f"signed: {any_type.is_signed}")
        # CHECK: storage type: i8
        print(f"storage type: {any_type.storage_type}")
        # CHECK: expressed type: f32
        print(f"expressed type: {any_type.expressed_type}")
        # CHECK: storage min: -8
        print(f"storage min: {any_type.storage_type_min}")
        # CHECK: storage max: 7
        print(f"storage max: {any_type.storage_type_max}")
        # CHECK: storage width: 8
        print(f"storage width: {any_type.storage_type_integral_width}")
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        print(f"quantized element type: {any_type.quantized_element_type}")
        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)
        assert any_type == Type.parse("!quant.any<i8<-8:7>:f32>")
67
68
69# CHECK-LABEL: TEST: test_uniform_type
@run
def test_uniform_type():
    """Build a ``UniformQuantizedType`` and print its properties for FileCheck.

    The type uses the signed flag, ``i8`` storage, ``f32`` expressed type,
    scale 0.99872, zero point 127 and a storage range of -8 to 7. The result
    is finally compared against the equivalent parsed type.
    """
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()
        signed_flag = quant.UniformQuantizedType.FLAG_SIGNED
        uniform_ty = quant.UniformQuantizedType.get(
            signed_flag, storage, expressed, 0.99872, 127, -8, 7
        )

        # CHECK: scale: 0.99872
        print(f"scale: {uniform_ty.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform_ty.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform_ty.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform_ty)

        # Round-trip: the constructed type equals the parsed textual form.
        parsed = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        assert uniform_ty == parsed
88
89
90# CHECK-LABEL: TEST: test_uniform_per_axis_type
@run
def test_uniform_per_axis_type():
    """Build a ``UniformQuantizedPerAxisType`` and print its properties.

    Two scales and two zero points are supplied for quantized dimension 1.
    The storage bounds come from the ``QuantizedType`` default minimum and
    maximum helpers for a signed 8-bit integer. The result is finally
    compared against the equivalent parsed type.
    """
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()

        # Default storage bounds for a signed integer of width 8.
        lowest = quant.QuantizedType.default_minimum_for_integer(
            is_signed=True, integral_width=8
        )
        highest = quant.QuantizedType.default_maximum_for_integer(
            is_signed=True, integral_width=8
        )

        axis_scales = [200, 0.99872]
        axis_zero_points = [0, 120]
        per_axis_ty = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            storage,
            expressed,
            axis_scales,
            axis_zero_points,
            quantized_dimension=1,
            storage_type_min=lowest,
            storage_type_max=highest,
        )

        # CHECK: scales: [200.0, 0.99872]
        print(f"scales: {per_axis_ty.scales}")
        # CHECK: zero_points: [0, 120]
        print(f"zero_points: {per_axis_ty.zero_points}")
        # CHECK: quantized dim: 1
        print(f"quantized dim: {per_axis_ty.quantized_dimension}")
        # CHECK: fixed point: False
        print(f"fixed point: {per_axis_ty.is_fixed_point}")
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis_ty)

        # Round-trip: the constructed type equals the parsed textual form.
        parsed = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        assert per_axis_ty == parsed
122
123
124# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
    """Build a ``CalibratedQuantizedType`` and print its bounds for FileCheck.

    The type is created over ``f32`` with a minimum of -0.998 and a maximum
    of 1.2321, then compared against the equivalent parsed type.
    """
    with Context():
        expressed = F32Type.get()
        calibrated_ty = quant.CalibratedQuantizedType.get(expressed, -0.998, 1.2321)

        # CHECK: min: -0.998
        print(f"min: {calibrated_ty.min}")
        # CHECK: max: 1.2321
        print(f"max: {calibrated_ty.max}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated_ty)

        # Round-trip: the constructed type equals the parsed textual form.
        parsed = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
        assert calibrated_ty == parsed
138