1a8a2ee63SDenys Shabalin //===- Quant.cpp - C Interface for Quant dialect --------------------------===// 29bcf13bfSAlex Zinenko // 39bcf13bfSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49bcf13bfSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 59bcf13bfSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69bcf13bfSAlex Zinenko // 79bcf13bfSAlex Zinenko //===----------------------------------------------------------------------===// 89bcf13bfSAlex Zinenko 99bcf13bfSAlex Zinenko #include "mlir-c/Dialect/Quant.h" 109bcf13bfSAlex Zinenko #include "mlir/CAPI/Registration.h" 11*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/Quant.h" 12*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/QuantTypes.h" 139bcf13bfSAlex Zinenko 149bcf13bfSAlex Zinenko using namespace mlir; 159bcf13bfSAlex Zinenko 16*852b6486SRafael Ubal MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) 179bcf13bfSAlex Zinenko 189bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 199bcf13bfSAlex Zinenko // QuantizedType 209bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 219bcf13bfSAlex Zinenko 229bcf13bfSAlex Zinenko bool mlirTypeIsAQuantizedType(MlirType type) { 235550c821STres Popp return isa<quant::QuantizedType>(unwrap(type)); 249bcf13bfSAlex Zinenko } 259bcf13bfSAlex Zinenko 269bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetSignedFlag() { 279bcf13bfSAlex Zinenko return quant::QuantizationFlags::Signed; 289bcf13bfSAlex Zinenko } 299bcf13bfSAlex Zinenko 309bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, 319bcf13bfSAlex Zinenko unsigned integralWidth) { 329bcf13bfSAlex Zinenko return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, 339bcf13bfSAlex Zinenko integralWidth); 349bcf13bfSAlex Zinenko } 359bcf13bfSAlex Zinenko 369bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, 379bcf13bfSAlex Zinenko unsigned integralWidth) { 389bcf13bfSAlex Zinenko return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, 399bcf13bfSAlex Zinenko integralWidth); 409bcf13bfSAlex Zinenko } 419bcf13bfSAlex Zinenko 429bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { 435550c821STres Popp return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType()); 449bcf13bfSAlex Zinenko } 459bcf13bfSAlex Zinenko 469bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetFlags(MlirType type) { 475550c821STres Popp return cast<quant::QuantizedType>(unwrap(type)).getFlags(); 489bcf13bfSAlex Zinenko } 499bcf13bfSAlex Zinenko 509bcf13bfSAlex Zinenko bool mlirQuantizedTypeIsSigned(MlirType type) { 515550c821STres Popp return cast<quant::QuantizedType>(unwrap(type)).isSigned(); 529bcf13bfSAlex Zinenko } 539bcf13bfSAlex Zinenko 549bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetStorageType(MlirType type) { 555550c821STres Popp return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType()); 569bcf13bfSAlex Zinenko } 579bcf13bfSAlex Zinenko 589bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { 595550c821STres Popp return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin(); 609bcf13bfSAlex Zinenko } 619bcf13bfSAlex Zinenko 629bcf13bfSAlex Zinenko int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { 635550c821STres Popp return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax(); 649bcf13bfSAlex Zinenko } 659bcf13bfSAlex Zinenko 669bcf13bfSAlex Zinenko unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { 675550c821STres Popp return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth(); 689bcf13bfSAlex Zinenko } 699bcf13bfSAlex Zinenko 709bcf13bfSAlex Zinenko bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, 719bcf13bfSAlex Zinenko MlirType candidate) { 725550c821STres Popp return cast<quant::QuantizedType>(unwrap(type)) 735550c821STres Popp .isCompatibleExpressedType(unwrap(candidate)); 749bcf13bfSAlex Zinenko } 759bcf13bfSAlex Zinenko 769bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { 779bcf13bfSAlex Zinenko return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type))); 789bcf13bfSAlex Zinenko } 799bcf13bfSAlex Zinenko 809bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, 819bcf13bfSAlex Zinenko MlirType candidate) { 825550c821STres Popp return wrap(cast<quant::QuantizedType>(unwrap(type)) 835550c821STres Popp .castFromStorageType(unwrap(candidate))); 849bcf13bfSAlex Zinenko } 859bcf13bfSAlex Zinenko 869bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { 879bcf13bfSAlex Zinenko return wrap(quant::QuantizedType::castToStorageType( 885550c821STres Popp cast<quant::QuantizedType>(unwrap(type)))); 899bcf13bfSAlex Zinenko } 909bcf13bfSAlex Zinenko 919bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, 929bcf13bfSAlex Zinenko MlirType candidate) { 935550c821STres Popp return wrap(cast<quant::QuantizedType>(unwrap(type)) 945550c821STres Popp .castFromExpressedType(unwrap(candidate))); 959bcf13bfSAlex Zinenko } 969bcf13bfSAlex Zinenko 979bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { 989bcf13bfSAlex Zinenko return wrap(quant::QuantizedType::castToExpressedType(unwrap(type))); 999bcf13bfSAlex Zinenko } 1009bcf13bfSAlex Zinenko 1019bcf13bfSAlex Zinenko MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, 1029bcf13bfSAlex Zinenko MlirType candidate) { 1035550c821STres Popp return wrap(cast<quant::QuantizedType>(unwrap(type)) 1045550c821STres Popp .castExpressedToStorageType(unwrap(candidate))); 1059bcf13bfSAlex Zinenko } 1069bcf13bfSAlex Zinenko 1079bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1089bcf13bfSAlex Zinenko // AnyQuantizedType 1099bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1109bcf13bfSAlex Zinenko 1119bcf13bfSAlex Zinenko bool mlirTypeIsAAnyQuantizedType(MlirType type) { 1125550c821STres Popp return isa<quant::AnyQuantizedType>(unwrap(type)); 1139bcf13bfSAlex Zinenko } 1149bcf13bfSAlex Zinenko 1159bcf13bfSAlex Zinenko MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, 1169bcf13bfSAlex Zinenko MlirType expressedType, int64_t storageTypeMin, 1179bcf13bfSAlex Zinenko int64_t storageTypeMax) { 1189bcf13bfSAlex Zinenko return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType), 1199bcf13bfSAlex Zinenko unwrap(expressedType), 1209bcf13bfSAlex Zinenko storageTypeMin, storageTypeMax)); 1219bcf13bfSAlex Zinenko } 1229bcf13bfSAlex Zinenko 1239bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1249bcf13bfSAlex Zinenko // UniformQuantizedType 1259bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1269bcf13bfSAlex Zinenko 1279bcf13bfSAlex Zinenko bool mlirTypeIsAUniformQuantizedType(MlirType type) { 1285550c821STres Popp return isa<quant::UniformQuantizedType>(unwrap(type)); 1299bcf13bfSAlex Zinenko } 1309bcf13bfSAlex Zinenko 1319bcf13bfSAlex Zinenko MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, 1329bcf13bfSAlex Zinenko MlirType expressedType, double scale, 1339bcf13bfSAlex Zinenko int64_t zeroPoint, int64_t storageTypeMin, 1349bcf13bfSAlex Zinenko int64_t storageTypeMax) { 1359bcf13bfSAlex Zinenko return wrap(quant::UniformQuantizedType::get( 1369bcf13bfSAlex Zinenko flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint, 1379bcf13bfSAlex Zinenko storageTypeMin, storageTypeMax)); 1389bcf13bfSAlex Zinenko } 1399bcf13bfSAlex Zinenko 1409bcf13bfSAlex Zinenko double mlirUniformQuantizedTypeGetScale(MlirType type) { 1415550c821STres Popp return cast<quant::UniformQuantizedType>(unwrap(type)).getScale(); 1429bcf13bfSAlex Zinenko } 1439bcf13bfSAlex Zinenko 1449bcf13bfSAlex Zinenko int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { 1455550c821STres Popp return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint(); 1469bcf13bfSAlex Zinenko } 1479bcf13bfSAlex Zinenko 1489bcf13bfSAlex Zinenko bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { 1495550c821STres Popp return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint(); 1509bcf13bfSAlex Zinenko } 1519bcf13bfSAlex Zinenko 1529bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1539bcf13bfSAlex Zinenko // UniformQuantizedPerAxisType 1549bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1559bcf13bfSAlex Zinenko 1569bcf13bfSAlex Zinenko bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { 1575550c821STres Popp return isa<quant::UniformQuantizedPerAxisType>(unwrap(type)); 1589bcf13bfSAlex Zinenko } 1599bcf13bfSAlex Zinenko 1609bcf13bfSAlex Zinenko MlirType mlirUniformQuantizedPerAxisTypeGet( 1619bcf13bfSAlex Zinenko unsigned flags, MlirType storageType, MlirType expressedType, 1629bcf13bfSAlex Zinenko intptr_t nDims, double *scales, int64_t *zeroPoints, 1639bcf13bfSAlex Zinenko int32_t quantizedDimension, int64_t storageTypeMin, 1649bcf13bfSAlex Zinenko int64_t storageTypeMax) { 1659bcf13bfSAlex Zinenko return wrap(quant::UniformQuantizedPerAxisType::get( 1669bcf13bfSAlex Zinenko flags, unwrap(storageType), unwrap(expressedType), 167984b800aSserge-sans-paille llvm::ArrayRef(scales, nDims), llvm::ArrayRef(zeroPoints, nDims), 1689bcf13bfSAlex Zinenko quantizedDimension, storageTypeMin, storageTypeMax)); 1699bcf13bfSAlex Zinenko } 1709bcf13bfSAlex Zinenko 1719bcf13bfSAlex Zinenko intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { 1725550c821STres Popp return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 1739bcf13bfSAlex Zinenko .getScales() 1749bcf13bfSAlex Zinenko .size(); 1759bcf13bfSAlex Zinenko } 1769bcf13bfSAlex Zinenko 1779bcf13bfSAlex Zinenko double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { 1785550c821STres Popp return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 1799bcf13bfSAlex Zinenko .getScales()[pos]; 1809bcf13bfSAlex Zinenko } 1819bcf13bfSAlex Zinenko 1829bcf13bfSAlex Zinenko int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, 1839bcf13bfSAlex Zinenko intptr_t pos) { 1845550c821STres Popp return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 1859bcf13bfSAlex Zinenko .getZeroPoints()[pos]; 1869bcf13bfSAlex Zinenko } 1879bcf13bfSAlex Zinenko 1889bcf13bfSAlex Zinenko int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { 1895550c821STres Popp return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 1909bcf13bfSAlex Zinenko .getQuantizedDimension(); 1919bcf13bfSAlex Zinenko } 1929bcf13bfSAlex Zinenko 1939bcf13bfSAlex Zinenko bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { 1945550c821STres Popp return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint(); 1959bcf13bfSAlex Zinenko } 1969bcf13bfSAlex Zinenko 1979bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 1989bcf13bfSAlex Zinenko // CalibratedQuantizedType 1999bcf13bfSAlex Zinenko //===---------------------------------------------------------------------===// 2009bcf13bfSAlex Zinenko 2019bcf13bfSAlex Zinenko bool mlirTypeIsACalibratedQuantizedType(MlirType type) { 2025550c821STres Popp return isa<quant::CalibratedQuantizedType>(unwrap(type)); 2039bcf13bfSAlex Zinenko } 2049bcf13bfSAlex Zinenko 2059bcf13bfSAlex Zinenko MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, 2069bcf13bfSAlex Zinenko double max) { 2079bcf13bfSAlex Zinenko return wrap( 2089bcf13bfSAlex Zinenko quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max)); 2099bcf13bfSAlex Zinenko } 2109bcf13bfSAlex Zinenko 2119bcf13bfSAlex Zinenko double mlirCalibratedQuantizedTypeGetMin(MlirType type) { 2125550c821STres Popp return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin(); 2139bcf13bfSAlex Zinenko } 2149bcf13bfSAlex Zinenko 2159bcf13bfSAlex Zinenko double mlirCalibratedQuantizedTypeGetMax(MlirType type) { 2165550c821STres Popp return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax(); 2179bcf13bfSAlex Zinenko } 218