xref: /llvm-project/mlir/test/CAPI/quant.c (revision 5d91f79fced13604ff401e5f5a6d5c3a9062ab20)
1 //===- quant.c - Test of Quant dialect C API ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 // RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
11 
12 #include "mlir-c/Dialect/Quant.h"
13 #include "mlir-c/BuiltinTypes.h"
14 #include "mlir-c/IR.h"
15 
16 #include <assert.h>
17 #include <inttypes.h>
18 #include <stdio.h>
19 #include <stdlib.h>
20 
21 // CHECK-LABEL: testTypeHierarchy
// Checks the isa-predicates of the quantized type hierarchy: one instance of
// every quantized kind (plus plain i8 as a negative control) is queried
// against the generic QuantizedType predicate, against its own specific
// predicate, and against a couple of unrelated predicates.
static void testTypeHierarchy(MlirContext ctx) {
  fprintf(stderr, "testTypeHierarchy\n");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType any = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
  MlirType uniform = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString(
               "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
  MlirType perAxis = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString(
               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
  MlirType calibrated = mlirTypeParseGet(
      ctx,
      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));

  // Parsing itself is validated by the C++ dialect tests; here we only need
  // the parsed types to be non-null.
  assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
  assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
  assert(!mlirTypeIsNull(perAxis) &&
         "couldn't parse UniformQuantizedPerAxisType");
  assert(!mlirTypeIsNull(calibrated) &&
         "couldn't parse CalibratedQuantizedType");

  // Each row is printed as "<subject> isa <kind>: <0|1>". Rows are emitted in
  // table order, which must stay in sync with the CHECK lines next to them.
  typedef bool (*IsaPredicate)(MlirType);
  const struct {
    const char *subject;
    MlirType type;
    const char *kind;
    IsaPredicate isa;
  } rows[] = {
      // CHECK: i8 isa QuantizedType: 0
      {"i8", i8, "QuantizedType", mlirTypeIsAQuantizedType},
      // CHECK: any isa QuantizedType: 1
      {"any", any, "QuantizedType", mlirTypeIsAQuantizedType},
      // CHECK: uniform isa QuantizedType: 1
      {"uniform", uniform, "QuantizedType", mlirTypeIsAQuantizedType},
      // CHECK: perAxis isa QuantizedType: 1
      {"perAxis", perAxis, "QuantizedType", mlirTypeIsAQuantizedType},
      // CHECK: calibrated isa QuantizedType: 1
      {"calibrated", calibrated, "QuantizedType", mlirTypeIsAQuantizedType},

      // CHECK: any isa AnyQuantizedType: 1
      {"any", any, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType},
      // CHECK: uniform isa UniformQuantizedType: 1
      {"uniform", uniform, "UniformQuantizedType",
       mlirTypeIsAUniformQuantizedType},
      // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
      {"perAxis", perAxis, "UniformQuantizedPerAxisType",
       mlirTypeIsAUniformQuantizedPerAxisType},
      // CHECK: calibrated isa CalibratedQuantizedType: 1
      {"calibrated", calibrated, "CalibratedQuantizedType",
       mlirTypeIsACalibratedQuantizedType},

      // CHECK: perAxis isa UniformQuantizedType: 0
      {"perAxis", perAxis, "UniformQuantizedType",
       mlirTypeIsAUniformQuantizedType},
      // CHECK: uniform isa CalibratedQuantizedType: 0
      {"uniform", uniform, "CalibratedQuantizedType",
       mlirTypeIsACalibratedQuantizedType},
  };

  for (size_t row = 0; row < sizeof rows / sizeof rows[0]; ++row)
    fprintf(stderr, "%s isa %s: %d\n", rows[row].subject, rows[row].kind,
            rows[row].isa(rows[row].type));
  fprintf(stderr, "\n");
}
81 
82 // CHECK-LABEL: testAnyQuantizedType
// Exercises the AnyQuantizedType C API. Builds !quant.any<i8<-8:7>:f32> once
// by parsing and once through mlirAnyQuantizedTypeGet, prints every generic
// QuantizedType accessor of the constructed type, and prints whether the two
// constructions compare equal.
static void testAnyQuantizedType(MlirContext ctx) {
  fprintf(stderr, "testAnyQuantizedType\n");

  MlirType anyParsed = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // Signed i8 storage restricted to [-8, 7], expressed as f32 -- the same
  // type as the parsed string above.
  MlirType any =
      mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7);

  // CHECK: flags: 1
  fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any));
  // CHECK: signed: 1
  // The signedness query yields a boolean, which promotes to int when passed
  // through varargs, so %d (not %u) is the matching conversion.
  fprintf(stderr, "signed: %d\n", mlirQuantizedTypeIsSigned(any));
  // CHECK: storage type: i8
  fprintf(stderr, "storage type: ");
  mlirTypeDump(mlirQuantizedTypeGetStorageType(any));
  fprintf(stderr, "\n");
  // CHECK: expressed type: f32
  fprintf(stderr, "expressed type: ");
  mlirTypeDump(mlirQuantizedTypeGetExpressedType(any));
  fprintf(stderr, "\n");
  // CHECK: storage min: -8
  fprintf(stderr, "storage min: %" PRId64 "\n",
          mlirQuantizedTypeGetStorageTypeMin(any));
  // CHECK: storage max: 7
  fprintf(stderr, "storage max: %" PRId64 "\n",
          mlirQuantizedTypeGetStorageTypeMax(any));
  // CHECK: storage width: 8
  fprintf(stderr, "storage width: %u\n",
          mlirQuantizedTypeGetStorageTypeIntegralWidth(any));
  // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
  fprintf(stderr, "quantized element type: ");
  mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any));
  fprintf(stderr, "\n");

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any));
  // CHECK: !quant.any<i8<-8:7>:f32>
  mlirTypeDump(any);
  fprintf(stderr, "\n\n");
}
126 
127 // CHECK-LABEL: testUniformType
// Exercises the UniformQuantizedType C API. Builds
// !quant.uniform<i8<-8:7>:f32, 0.99872:127> once by parsing and once through
// mlirUniformQuantizedTypeGet, prints its scale / zero point / fixed-point
// accessors, and prints whether the two constructions compare equal.
static void testUniformType(MlirContext ctx) {
  fprintf(stderr, "testUniformType\n");

  MlirType uniformParsed =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // After the element types: scale 0.99872, zero point 127 and storage range
  // [-8, 7], mirroring the parsed form above.
  MlirType uniform = mlirUniformQuantizedTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7);

  // CHECK: scale: 0.998720
  // %f already takes a double; the `l` in %lf is a no-op since C99 and
  // undefined in C90, so it is dropped.
  fprintf(stderr, "scale: %f\n", mlirUniformQuantizedTypeGetScale(uniform));
  // CHECK: zero point: 127
  fprintf(stderr, "zero point: %" PRId64 "\n",
          mlirUniformQuantizedTypeGetZeroPoint(uniform));
  // CHECK: fixed point: 0
  fprintf(stderr, "fixed point: %d\n",
          mlirUniformQuantizedTypeIsFixedPoint(uniform));

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed));
  // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
  mlirTypeDump(uniform);
  fprintf(stderr, "\n\n");
}
155 
156 // CHECK-LABEL: testUniformPerAxisType
// Exercises the UniformQuantizedPerAxisType C API. Builds
// !quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}> once by parsing and once
// through mlirUniformQuantizedPerAxisTypeGet, prints each per-axis accessor,
// and prints whether the two constructions compare equal.
static void testUniformPerAxisType(MlirContext ctx) {
  fprintf(stderr, "testUniformPerAxisType\n");

  MlirType perAxisParsed = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString(
               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // Parallel arrays: scales[i] is paired with zeroPoints[i], matching the
  // "{2.0e+2,0.99872:120}" list in the parsed form (omitted zero point = 0).
  double scales[] = {200.0, 0.99872};
  int64_t zeroPoints[] = {0, 120};
  // Derive the entry count from the arrays instead of hard-coding it, so the
  // count passed to the API cannot drift out of sync with the data.
  const intptr_t nDims = (intptr_t)(sizeof scales / sizeof scales[0]);
  assert(nDims == (intptr_t)(sizeof zeroPoints / sizeof zeroPoints[0]) &&
         "scales and zeroPoints must have the same length");
  MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32, nDims, scales, zeroPoints,
      /*quantizedDimension=*/1,
      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
                                                   /*integralWidth=*/8),
      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
                                                   /*integralWidth=*/8));

  // CHECK: num dims: 2
  fprintf(stderr, "num dims: %" PRIdPTR "\n",
          mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis));
  // CHECK: scale 0: 200.000000
  fprintf(stderr, "scale 0: %f\n",
          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 0));
  // CHECK: scale 1: 0.998720
  fprintf(stderr, "scale 1: %f\n",
          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 1));
  // CHECK: zero point 0: 0
  fprintf(stderr, "zero point 0: %" PRId64 "\n",
          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 0));
  // CHECK: zero point 1: 120
  fprintf(stderr, "zero point 1: %" PRId64 "\n",
          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 1));
  // CHECK: quantized dim: 1
  fprintf(stderr, "quantized dim: %" PRId32 "\n",
          mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis));
  // CHECK: fixed point: 0
  fprintf(stderr, "fixed point: %d\n",
          mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis));

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed));
  // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
  mlirTypeDump(perAxis);
  fprintf(stderr, "\n\n");
}
205 
206 // CHECK-LABEL: testCalibratedType
// Exercises the CalibratedQuantizedType C API. Builds
// !quant.calibrated<f32<-0.998:1.2321>> once by parsing and once through
// mlirCalibratedQuantizedTypeGet, prints its min/max accessors, and prints
// whether the two constructions compare equal.
static void testCalibratedType(MlirContext ctx) {
  fprintf(stderr, "testCalibratedType\n");

  MlirType calibratedParsed = mlirTypeParseGet(
      ctx,
      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));

  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321);

  // CHECK: min: -0.998000
  // %f is the portable conversion for double (%lf is undefined in C90).
  fprintf(stderr, "min: %f\n", mlirCalibratedQuantizedTypeGetMin(calibrated));
  // CHECK: max: 1.232100
  fprintf(stderr, "max: %f\n", mlirCalibratedQuantizedTypeGetMax(calibrated));

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed));
  // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
  mlirTypeDump(calibrated);
  fprintf(stderr, "\n\n");
}
228 
main(void)229 int main(void) {
230   MlirContext ctx = mlirContextCreate();
231   mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
232   testTypeHierarchy(ctx);
233   testAnyQuantizedType(ctx);
234   testUniformType(ctx);
235   testUniformPerAxisType(ctx);
236   testCalibratedType(ctx);
237   mlirContextDestroy(ctx);
238   return EXIT_SUCCESS;
239 }
240