xref: /llvm-project/mlir/test/CAPI/quant.c (revision 5d91f79fced13604ff401e5f5a6d5c3a9062ab20)
19bcf13bfSAlex Zinenko //===- quant.c - Test of Quant dialect C API ------------------------------===//
29bcf13bfSAlex Zinenko //
39bcf13bfSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM
49bcf13bfSAlex Zinenko // Exceptions.
59bcf13bfSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
69bcf13bfSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
79bcf13bfSAlex Zinenko //
89bcf13bfSAlex Zinenko //===----------------------------------------------------------------------===//
99bcf13bfSAlex Zinenko 
109bcf13bfSAlex Zinenko // RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
119bcf13bfSAlex Zinenko 
129bcf13bfSAlex Zinenko #include "mlir-c/Dialect/Quant.h"
139bcf13bfSAlex Zinenko #include "mlir-c/BuiltinTypes.h"
149bcf13bfSAlex Zinenko #include "mlir-c/IR.h"
159bcf13bfSAlex Zinenko 
169bcf13bfSAlex Zinenko #include <assert.h>
179bcf13bfSAlex Zinenko #include <inttypes.h>
189bcf13bfSAlex Zinenko #include <stdio.h>
199bcf13bfSAlex Zinenko #include <stdlib.h>
209bcf13bfSAlex Zinenko 
219bcf13bfSAlex Zinenko // CHECK-LABEL: testTypeHierarchy
testTypeHierarchy(MlirContext ctx)229bcf13bfSAlex Zinenko static void testTypeHierarchy(MlirContext ctx) {
239bcf13bfSAlex Zinenko   fprintf(stderr, "testTypeHierarchy\n");
249bcf13bfSAlex Zinenko 
259bcf13bfSAlex Zinenko   MlirType i8 = mlirIntegerTypeGet(ctx, 8);
269bcf13bfSAlex Zinenko   MlirType any = mlirTypeParseGet(
279bcf13bfSAlex Zinenko       ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
289bcf13bfSAlex Zinenko   MlirType uniform =
299bcf13bfSAlex Zinenko       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
309bcf13bfSAlex Zinenko                                 "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
319bcf13bfSAlex Zinenko   MlirType perAxis = mlirTypeParseGet(
329bcf13bfSAlex Zinenko       ctx, mlirStringRefCreateFromCString(
339bcf13bfSAlex Zinenko                "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
349bcf13bfSAlex Zinenko   MlirType calibrated = mlirTypeParseGet(
359bcf13bfSAlex Zinenko       ctx,
369bcf13bfSAlex Zinenko       mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
379bcf13bfSAlex Zinenko 
389bcf13bfSAlex Zinenko   // The parser itself is checked in C++ dialect tests.
399bcf13bfSAlex Zinenko   assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
409bcf13bfSAlex Zinenko   assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
419bcf13bfSAlex Zinenko   assert(!mlirTypeIsNull(perAxis) &&
429bcf13bfSAlex Zinenko          "couldn't parse UniformQuantizedPerAxisType");
439bcf13bfSAlex Zinenko   assert(!mlirTypeIsNull(calibrated) &&
449bcf13bfSAlex Zinenko          "couldn't parse CalibratedQuantizedType");
459bcf13bfSAlex Zinenko 
469bcf13bfSAlex Zinenko   // CHECK: i8 isa QuantizedType: 0
479bcf13bfSAlex Zinenko   fprintf(stderr, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8));
489bcf13bfSAlex Zinenko   // CHECK: any isa QuantizedType: 1
499bcf13bfSAlex Zinenko   fprintf(stderr, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any));
509bcf13bfSAlex Zinenko   // CHECK: uniform isa QuantizedType: 1
519bcf13bfSAlex Zinenko   fprintf(stderr, "uniform isa QuantizedType: %d\n",
529bcf13bfSAlex Zinenko           mlirTypeIsAQuantizedType(uniform));
539bcf13bfSAlex Zinenko   // CHECK: perAxis isa QuantizedType: 1
549bcf13bfSAlex Zinenko   fprintf(stderr, "perAxis isa QuantizedType: %d\n",
559bcf13bfSAlex Zinenko           mlirTypeIsAQuantizedType(perAxis));
569bcf13bfSAlex Zinenko   // CHECK: calibrated isa QuantizedType: 1
579bcf13bfSAlex Zinenko   fprintf(stderr, "calibrated isa QuantizedType: %d\n",
589bcf13bfSAlex Zinenko           mlirTypeIsAQuantizedType(calibrated));
599bcf13bfSAlex Zinenko 
609bcf13bfSAlex Zinenko   // CHECK: any isa AnyQuantizedType: 1
619bcf13bfSAlex Zinenko   fprintf(stderr, "any isa AnyQuantizedType: %d\n",
629bcf13bfSAlex Zinenko           mlirTypeIsAAnyQuantizedType(any));
639bcf13bfSAlex Zinenko   // CHECK: uniform isa UniformQuantizedType: 1
649bcf13bfSAlex Zinenko   fprintf(stderr, "uniform isa UniformQuantizedType: %d\n",
659bcf13bfSAlex Zinenko           mlirTypeIsAUniformQuantizedType(uniform));
669bcf13bfSAlex Zinenko   // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
679bcf13bfSAlex Zinenko   fprintf(stderr, "perAxis isa UniformQuantizedPerAxisType: %d\n",
689bcf13bfSAlex Zinenko           mlirTypeIsAUniformQuantizedPerAxisType(perAxis));
699bcf13bfSAlex Zinenko   // CHECK: calibrated isa CalibratedQuantizedType: 1
709bcf13bfSAlex Zinenko   fprintf(stderr, "calibrated isa CalibratedQuantizedType: %d\n",
719bcf13bfSAlex Zinenko           mlirTypeIsACalibratedQuantizedType(calibrated));
729bcf13bfSAlex Zinenko 
739bcf13bfSAlex Zinenko   // CHECK: perAxis isa UniformQuantizedType: 0
749bcf13bfSAlex Zinenko   fprintf(stderr, "perAxis isa UniformQuantizedType: %d\n",
759bcf13bfSAlex Zinenko           mlirTypeIsAUniformQuantizedType(perAxis));
769bcf13bfSAlex Zinenko   // CHECK: uniform isa CalibratedQuantizedType: 0
779bcf13bfSAlex Zinenko   fprintf(stderr, "uniform isa CalibratedQuantizedType: %d\n",
789bcf13bfSAlex Zinenko           mlirTypeIsACalibratedQuantizedType(uniform));
799bcf13bfSAlex Zinenko   fprintf(stderr, "\n");
809bcf13bfSAlex Zinenko }
819bcf13bfSAlex Zinenko 
829bcf13bfSAlex Zinenko // CHECK-LABEL: testAnyQuantizedType
// Builds `!quant.any<i8<-8:7>:f32>` through mlirAnyQuantizedTypeGet, prints
// the properties exposed by the generic QuantizedType accessors, and compares
// the result with the same type parsed from its textual form.
// File-local like testTypeHierarchy: it is only called from main.
static void testAnyQuantizedType(MlirContext ctx) {
  fprintf(stderr, "testAnyQuantizedType\n");

  MlirType anyParsed = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
  // In assertion-enabled builds, stop at the actual parse failure instead of
  // only reporting "equal: 0" further down.
  assert(!mlirTypeIsNull(anyParsed) && "couldn't parse AnyQuantizedType");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // Signed i8 storage restricted to [-8, 7], expressed as f32.
  MlirType any =
      mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7);

  // CHECK: flags: 1
  fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any));
  // CHECK: signed: 1
  // The signedness query is a boolean predicate (promoted to int when passed
  // through varargs), so print it with %d like the other predicates here.
  fprintf(stderr, "signed: %d\n", mlirQuantizedTypeIsSigned(any));
  // CHECK: storage type: i8
  fprintf(stderr, "storage type: ");
  mlirTypeDump(mlirQuantizedTypeGetStorageType(any));
  fprintf(stderr, "\n");
  // CHECK: expressed type: f32
  fprintf(stderr, "expressed type: ");
  mlirTypeDump(mlirQuantizedTypeGetExpressedType(any));
  fprintf(stderr, "\n");
  // CHECK: storage min: -8
  fprintf(stderr, "storage min: %" PRId64 "\n",
          mlirQuantizedTypeGetStorageTypeMin(any));
  // CHECK: storage max: 7
  fprintf(stderr, "storage max: %" PRId64 "\n",
          mlirQuantizedTypeGetStorageTypeMax(any));
  // CHECK: storage width: 8
  fprintf(stderr, "storage width: %u\n",
          mlirQuantizedTypeGetStorageTypeIntegralWidth(any));
  // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
  fprintf(stderr, "quantized element type: ");
  mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any));
  fprintf(stderr, "\n");

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any));
  // CHECK: !quant.any<i8<-8:7>:f32>
  mlirTypeDump(any);
  fprintf(stderr, "\n\n");
}
1269bcf13bfSAlex Zinenko 
1279bcf13bfSAlex Zinenko // CHECK-LABEL: testUniformType
// Builds a uniform quantized type with a single scale (0.99872) and zero
// point (127) over signed i8 storage restricted to [-8, 7] and expressed as
// f32. Prints its scale, zero point and fixed-point flag, then compares it
// with the same type parsed from text.
// File-local like testTypeHierarchy: it is only called from main.
static void testUniformType(MlirContext ctx) {
  fprintf(stderr, "testUniformType\n");

  MlirType uniformParsed =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
  // In assertion-enabled builds, stop at the actual parse failure instead of
  // only reporting "equal: 0" further down.
  assert(!mlirTypeIsNull(uniformParsed) &&
         "couldn't parse UniformQuantizedType");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType uniform = mlirUniformQuantizedTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7);

  // CHECK: scale: 0.998720
  fprintf(stderr, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform));
  // CHECK: zero point: 127
  fprintf(stderr, "zero point: %" PRId64 "\n",
          mlirUniformQuantizedTypeGetZeroPoint(uniform));
  // CHECK: fixed point: 0
  fprintf(stderr, "fixed point: %d\n",
          mlirUniformQuantizedTypeIsFixedPoint(uniform));

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed));
  // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
  mlirTypeDump(uniform);
  fprintf(stderr, "\n\n");
}
1559bcf13bfSAlex Zinenko 
1569bcf13bfSAlex Zinenko // CHECK-LABEL: testUniformPerAxisType
// Builds a per-axis uniform quantized type (quantized dimension 1, two
// scale/zero-point pairs, default signed 8-bit storage range, f32 expressed
// type). Prints every accessor result, then compares it with the same type
// parsed from text.
// File-local like testTypeHierarchy: it is only called from main.
static void testUniformPerAxisType(MlirContext ctx) {
  fprintf(stderr, "testUniformPerAxisType\n");

  MlirType perAxisParsed = mlirTypeParseGet(
      ctx, mlirStringRefCreateFromCString(
               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
  // In assertion-enabled builds, stop at the actual parse failure instead of
  // only reporting "equal: 0" further down.
  assert(!mlirTypeIsNull(perAxisParsed) &&
         "couldn't parse UniformQuantizedPerAxisType");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);
  // One scale and one zero point per entry of the per-axis list. A single
  // named count sizes both arrays and the nDims argument so they cannot drift
  // apart.
  enum { kNumScales = 2 };
  double scales[kNumScales] = {200.0, 0.99872};
  int64_t zeroPoints[kNumScales] = {0, 120};
  MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32,
      /*nDims=*/kNumScales, scales, zeroPoints,
      /*quantizedDimension=*/1,
      // Default range of a signed 8-bit integer, matching the parsed form,
      // which spells the storage type as plain `i8` with no `<min:max>`.
      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
                                                   /*integralWidth=*/8),
      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
                                                   /*integralWidth=*/8));

  // CHECK: num dims: 2
  fprintf(stderr, "num dims: %" PRIdPTR "\n",
          mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis));
  // CHECK: scale 0: 200.000000
  fprintf(stderr, "scale 0: %lf\n",
          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 0));
  // CHECK: scale 1: 0.998720
  fprintf(stderr, "scale 1: %lf\n",
          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 1));
  // CHECK: zero point 0: 0
  fprintf(stderr, "zero point 0: %" PRId64 "\n",
          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 0));
  // CHECK: zero point 1: 120
  fprintf(stderr, "zero point 1: %" PRId64 "\n",
          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 1));
  // CHECK: quantized dim: 1
  fprintf(stderr, "quantized dim: %" PRId32 "\n",
          mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis));
  // CHECK: fixed point: 0
  fprintf(stderr, "fixed point: %d\n",
          mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis));

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed));
  // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
  mlirTypeDump(perAxis);
  fprintf(stderr, "\n\n");
}
2059bcf13bfSAlex Zinenko 
2069bcf13bfSAlex Zinenko // CHECK-LABEL: testCalibratedType
// Builds `!quant.calibrated<f32<-0.998:1.2321>>` via
// mlirCalibratedQuantizedTypeGet, prints its min/max, and compares it with
// the same type parsed from text.
// File-local like testTypeHierarchy: it is only called from main.
static void testCalibratedType(MlirContext ctx) {
  fprintf(stderr, "testCalibratedType\n");

  MlirType calibratedParsed = mlirTypeParseGet(
      ctx,
      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
  // In assertion-enabled builds, stop at the actual parse failure instead of
  // only reporting "equal: 0" further down.
  assert(!mlirTypeIsNull(calibratedParsed) &&
         "couldn't parse CalibratedQuantizedType");

  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321);

  // CHECK: min: -0.998000
  fprintf(stderr, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated));
  // CHECK: max: 1.232100
  fprintf(stderr, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated));

  // CHECK: equal: 1
  fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed));
  // The printed form keeps -0.998 as written but renders the max in
  // exponent notation.
  // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
  mlirTypeDump(calibrated);
  fprintf(stderr, "\n\n");
}
2289bcf13bfSAlex Zinenko 
main(void)229*5d91f79fSTom Eccles int main(void) {
2309bcf13bfSAlex Zinenko   MlirContext ctx = mlirContextCreate();
2319bcf13bfSAlex Zinenko   mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
2329bcf13bfSAlex Zinenko   testTypeHierarchy(ctx);
2339bcf13bfSAlex Zinenko   testAnyQuantizedType(ctx);
2349bcf13bfSAlex Zinenko   testUniformType(ctx);
2359bcf13bfSAlex Zinenko   testUniformPerAxisType(ctx);
2369bcf13bfSAlex Zinenko   testCalibratedType(ctx);
2379bcf13bfSAlex Zinenko   mlirContextDestroy(ctx);
2389bcf13bfSAlex Zinenko   return EXIT_SUCCESS;
2399bcf13bfSAlex Zinenko }
240