xref: /llvm-project/mlir/lib/CAPI/Dialect/Quant.cpp (revision 852b6486246141e44cc9f126f542a2ae0d73b3d6)
1a8a2ee63SDenys Shabalin //===- Quant.cpp - C Interface for Quant dialect --------------------------===//
29bcf13bfSAlex Zinenko //
39bcf13bfSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49bcf13bfSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
59bcf13bfSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69bcf13bfSAlex Zinenko //
79bcf13bfSAlex Zinenko //===----------------------------------------------------------------------===//
89bcf13bfSAlex Zinenko 
99bcf13bfSAlex Zinenko #include "mlir-c/Dialect/Quant.h"
109bcf13bfSAlex Zinenko #include "mlir/CAPI/Registration.h"
11*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/Quant.h"
12*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/QuantTypes.h"
139bcf13bfSAlex Zinenko 
149bcf13bfSAlex Zinenko using namespace mlir;
159bcf13bfSAlex Zinenko 
// Invokes the C API dialect-registration macro for quant::QuantDialect.
// NOTE(review): the macro is presumably provided by
// mlir/CAPI/Registration.h; confirm what the two leading `quant` arguments
// name before changing them.
16*852b6486SRafael Ubal MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
179bcf13bfSAlex Zinenko 
189bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
199bcf13bfSAlex Zinenko // QuantizedType
209bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
219bcf13bfSAlex Zinenko 
229bcf13bfSAlex Zinenko bool mlirTypeIsAQuantizedType(MlirType type) {
235550c821STres Popp   return isa<quant::QuantizedType>(unwrap(type));
249bcf13bfSAlex Zinenko }
259bcf13bfSAlex Zinenko 
269bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetSignedFlag() {
279bcf13bfSAlex Zinenko   return quant::QuantizationFlags::Signed;
289bcf13bfSAlex Zinenko }
299bcf13bfSAlex Zinenko 
309bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
319bcf13bfSAlex Zinenko                                                      unsigned integralWidth) {
329bcf13bfSAlex Zinenko   return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
339bcf13bfSAlex Zinenko                                                            integralWidth);
349bcf13bfSAlex Zinenko }
359bcf13bfSAlex Zinenko 
369bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
379bcf13bfSAlex Zinenko                                                      unsigned integralWidth) {
389bcf13bfSAlex Zinenko   return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
399bcf13bfSAlex Zinenko                                                            integralWidth);
409bcf13bfSAlex Zinenko }
419bcf13bfSAlex Zinenko 
429bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
435550c821STres Popp   return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType());
449bcf13bfSAlex Zinenko }
459bcf13bfSAlex Zinenko 
469bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetFlags(MlirType type) {
475550c821STres Popp   return cast<quant::QuantizedType>(unwrap(type)).getFlags();
489bcf13bfSAlex Zinenko }
499bcf13bfSAlex Zinenko 
509bcf13bfSAlex Zinenko bool mlirQuantizedTypeIsSigned(MlirType type) {
515550c821STres Popp   return cast<quant::QuantizedType>(unwrap(type)).isSigned();
529bcf13bfSAlex Zinenko }
539bcf13bfSAlex Zinenko 
549bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
555550c821STres Popp   return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType());
569bcf13bfSAlex Zinenko }
579bcf13bfSAlex Zinenko 
589bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
595550c821STres Popp   return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin();
609bcf13bfSAlex Zinenko }
619bcf13bfSAlex Zinenko 
629bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
635550c821STres Popp   return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax();
649bcf13bfSAlex Zinenko }
659bcf13bfSAlex Zinenko 
669bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
675550c821STres Popp   return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth();
689bcf13bfSAlex Zinenko }
699bcf13bfSAlex Zinenko 
709bcf13bfSAlex Zinenko bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
719bcf13bfSAlex Zinenko                                                 MlirType candidate) {
725550c821STres Popp   return cast<quant::QuantizedType>(unwrap(type))
735550c821STres Popp       .isCompatibleExpressedType(unwrap(candidate));
749bcf13bfSAlex Zinenko }
759bcf13bfSAlex Zinenko 
769bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
779bcf13bfSAlex Zinenko   return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type)));
789bcf13bfSAlex Zinenko }
799bcf13bfSAlex Zinenko 
809bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
819bcf13bfSAlex Zinenko                                               MlirType candidate) {
825550c821STres Popp   return wrap(cast<quant::QuantizedType>(unwrap(type))
835550c821STres Popp                   .castFromStorageType(unwrap(candidate)));
849bcf13bfSAlex Zinenko }
859bcf13bfSAlex Zinenko 
869bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
879bcf13bfSAlex Zinenko   return wrap(quant::QuantizedType::castToStorageType(
885550c821STres Popp       cast<quant::QuantizedType>(unwrap(type))));
899bcf13bfSAlex Zinenko }
909bcf13bfSAlex Zinenko 
919bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
929bcf13bfSAlex Zinenko                                                 MlirType candidate) {
935550c821STres Popp   return wrap(cast<quant::QuantizedType>(unwrap(type))
945550c821STres Popp                   .castFromExpressedType(unwrap(candidate)));
959bcf13bfSAlex Zinenko }
969bcf13bfSAlex Zinenko 
979bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
989bcf13bfSAlex Zinenko   return wrap(quant::QuantizedType::castToExpressedType(unwrap(type)));
999bcf13bfSAlex Zinenko }
1009bcf13bfSAlex Zinenko 
1019bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
1029bcf13bfSAlex Zinenko                                                      MlirType candidate) {
1035550c821STres Popp   return wrap(cast<quant::QuantizedType>(unwrap(type))
1045550c821STres Popp                   .castExpressedToStorageType(unwrap(candidate)));
1059bcf13bfSAlex Zinenko }
1069bcf13bfSAlex Zinenko 
1079bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1089bcf13bfSAlex Zinenko // AnyQuantizedType
1099bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1109bcf13bfSAlex Zinenko 
1119bcf13bfSAlex Zinenko bool mlirTypeIsAAnyQuantizedType(MlirType type) {
1125550c821STres Popp   return isa<quant::AnyQuantizedType>(unwrap(type));
1139bcf13bfSAlex Zinenko }
1149bcf13bfSAlex Zinenko 
1159bcf13bfSAlex Zinenko MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
1169bcf13bfSAlex Zinenko                                  MlirType expressedType, int64_t storageTypeMin,
1179bcf13bfSAlex Zinenko                                  int64_t storageTypeMax) {
1189bcf13bfSAlex Zinenko   return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType),
1199bcf13bfSAlex Zinenko                                            unwrap(expressedType),
1209bcf13bfSAlex Zinenko                                            storageTypeMin, storageTypeMax));
1219bcf13bfSAlex Zinenko }
1229bcf13bfSAlex Zinenko 
1239bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1249bcf13bfSAlex Zinenko // UniformQuantizedType
1259bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1269bcf13bfSAlex Zinenko 
1279bcf13bfSAlex Zinenko bool mlirTypeIsAUniformQuantizedType(MlirType type) {
1285550c821STres Popp   return isa<quant::UniformQuantizedType>(unwrap(type));
1299bcf13bfSAlex Zinenko }
1309bcf13bfSAlex Zinenko 
1319bcf13bfSAlex Zinenko MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
1329bcf13bfSAlex Zinenko                                      MlirType expressedType, double scale,
1339bcf13bfSAlex Zinenko                                      int64_t zeroPoint, int64_t storageTypeMin,
1349bcf13bfSAlex Zinenko                                      int64_t storageTypeMax) {
1359bcf13bfSAlex Zinenko   return wrap(quant::UniformQuantizedType::get(
1369bcf13bfSAlex Zinenko       flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint,
1379bcf13bfSAlex Zinenko       storageTypeMin, storageTypeMax));
1389bcf13bfSAlex Zinenko }
1399bcf13bfSAlex Zinenko 
1409bcf13bfSAlex Zinenko double mlirUniformQuantizedTypeGetScale(MlirType type) {
1415550c821STres Popp   return cast<quant::UniformQuantizedType>(unwrap(type)).getScale();
1429bcf13bfSAlex Zinenko }
1439bcf13bfSAlex Zinenko 
1449bcf13bfSAlex Zinenko int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
1455550c821STres Popp   return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint();
1469bcf13bfSAlex Zinenko }
1479bcf13bfSAlex Zinenko 
1489bcf13bfSAlex Zinenko bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
1495550c821STres Popp   return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint();
1509bcf13bfSAlex Zinenko }
1519bcf13bfSAlex Zinenko 
1529bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1539bcf13bfSAlex Zinenko // UniformQuantizedPerAxisType
1549bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1559bcf13bfSAlex Zinenko 
1569bcf13bfSAlex Zinenko bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
1575550c821STres Popp   return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
1589bcf13bfSAlex Zinenko }
1599bcf13bfSAlex Zinenko 
1609bcf13bfSAlex Zinenko MlirType mlirUniformQuantizedPerAxisTypeGet(
1619bcf13bfSAlex Zinenko     unsigned flags, MlirType storageType, MlirType expressedType,
1629bcf13bfSAlex Zinenko     intptr_t nDims, double *scales, int64_t *zeroPoints,
1639bcf13bfSAlex Zinenko     int32_t quantizedDimension, int64_t storageTypeMin,
1649bcf13bfSAlex Zinenko     int64_t storageTypeMax) {
1659bcf13bfSAlex Zinenko   return wrap(quant::UniformQuantizedPerAxisType::get(
1669bcf13bfSAlex Zinenko       flags, unwrap(storageType), unwrap(expressedType),
167984b800aSserge-sans-paille       llvm::ArrayRef(scales, nDims), llvm::ArrayRef(zeroPoints, nDims),
1689bcf13bfSAlex Zinenko       quantizedDimension, storageTypeMin, storageTypeMax));
1699bcf13bfSAlex Zinenko }
1709bcf13bfSAlex Zinenko 
1719bcf13bfSAlex Zinenko intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
1725550c821STres Popp   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
1739bcf13bfSAlex Zinenko       .getScales()
1749bcf13bfSAlex Zinenko       .size();
1759bcf13bfSAlex Zinenko }
1769bcf13bfSAlex Zinenko 
1779bcf13bfSAlex Zinenko double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
1785550c821STres Popp   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
1799bcf13bfSAlex Zinenko       .getScales()[pos];
1809bcf13bfSAlex Zinenko }
1819bcf13bfSAlex Zinenko 
1829bcf13bfSAlex Zinenko int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
1839bcf13bfSAlex Zinenko                                                     intptr_t pos) {
1845550c821STres Popp   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
1859bcf13bfSAlex Zinenko       .getZeroPoints()[pos];
1869bcf13bfSAlex Zinenko }
1879bcf13bfSAlex Zinenko 
1889bcf13bfSAlex Zinenko int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
1895550c821STres Popp   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
1909bcf13bfSAlex Zinenko       .getQuantizedDimension();
1919bcf13bfSAlex Zinenko }
1929bcf13bfSAlex Zinenko 
1939bcf13bfSAlex Zinenko bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
1945550c821STres Popp   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
1959bcf13bfSAlex Zinenko }
1969bcf13bfSAlex Zinenko 
1979bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
1989bcf13bfSAlex Zinenko // CalibratedQuantizedType
1999bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===//
2009bcf13bfSAlex Zinenko 
2019bcf13bfSAlex Zinenko bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
2025550c821STres Popp   return isa<quant::CalibratedQuantizedType>(unwrap(type));
2039bcf13bfSAlex Zinenko }
2049bcf13bfSAlex Zinenko 
2059bcf13bfSAlex Zinenko MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
2069bcf13bfSAlex Zinenko                                         double max) {
2079bcf13bfSAlex Zinenko   return wrap(
2089bcf13bfSAlex Zinenko       quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
2099bcf13bfSAlex Zinenko }
2109bcf13bfSAlex Zinenko 
2119bcf13bfSAlex Zinenko double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
2125550c821STres Popp   return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin();
2139bcf13bfSAlex Zinenko }
2149bcf13bfSAlex Zinenko 
2159bcf13bfSAlex Zinenko double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
2165550c821STres Popp   return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax();
2179bcf13bfSAlex Zinenko }
218