1 //===- Quant.cpp - C Interface for Quant dialect --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/Quant.h" 10 #include "mlir/CAPI/Registration.h" 11 #include "mlir/Dialect/Quant/IR/Quant.h" 12 #include "mlir/Dialect/Quant/IR/QuantTypes.h" 13 14 using namespace mlir; 15 16 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) 17 18 //===---------------------------------------------------------------------===// 19 // QuantizedType 20 //===---------------------------------------------------------------------===// 21 22 bool mlirTypeIsAQuantizedType(MlirType type) { 23 return isa<quant::QuantizedType>(unwrap(type)); 24 } 25 26 unsigned mlirQuantizedTypeGetSignedFlag() { 27 return quant::QuantizationFlags::Signed; 28 } 29 30 int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, 31 unsigned integralWidth) { 32 return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, 33 integralWidth); 34 } 35 36 int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, 37 unsigned integralWidth) { 38 return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, 39 integralWidth); 40 } 41 42 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { 43 return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType()); 44 } 45 46 unsigned mlirQuantizedTypeGetFlags(MlirType type) { 47 return cast<quant::QuantizedType>(unwrap(type)).getFlags(); 48 } 49 50 bool mlirQuantizedTypeIsSigned(MlirType type) { 51 return cast<quant::QuantizedType>(unwrap(type)).isSigned(); 52 } 53 54 MlirType mlirQuantizedTypeGetStorageType(MlirType type) { 55 return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType()); 56 } 57 58 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { 59 return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin(); 60 } 61 62 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { 63 return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax(); 64 } 65 66 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { 67 return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth(); 68 } 69 70 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, 71 MlirType candidate) { 72 return cast<quant::QuantizedType>(unwrap(type)) 73 .isCompatibleExpressedType(unwrap(candidate)); 74 } 75 76 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { 77 return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type))); 78 } 79 80 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, 81 MlirType candidate) { 82 return wrap(cast<quant::QuantizedType>(unwrap(type)) 83 .castFromStorageType(unwrap(candidate))); 84 } 85 86 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { 87 return wrap(quant::QuantizedType::castToStorageType( 88 cast<quant::QuantizedType>(unwrap(type)))); 89 } 90 91 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, 92 MlirType candidate) { 93 return wrap(cast<quant::QuantizedType>(unwrap(type)) 94 .castFromExpressedType(unwrap(candidate))); 95 } 96 97 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { 98 return wrap(quant::QuantizedType::castToExpressedType(unwrap(type))); 99 } 100 101 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, 102 MlirType candidate) { 103 return wrap(cast<quant::QuantizedType>(unwrap(type)) 104 .castExpressedToStorageType(unwrap(candidate))); 105 } 106 107 //===---------------------------------------------------------------------===// 108 // AnyQuantizedType 109 //===---------------------------------------------------------------------===// 110 111 bool mlirTypeIsAAnyQuantizedType(MlirType type) { 112 return isa<quant::AnyQuantizedType>(unwrap(type)); 113 } 114 115 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, 116 MlirType expressedType, int64_t storageTypeMin, 117 int64_t storageTypeMax) { 118 return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType), 119 unwrap(expressedType), 120 storageTypeMin, storageTypeMax)); 121 } 122 123 //===---------------------------------------------------------------------===// 124 // UniformQuantizedType 125 //===---------------------------------------------------------------------===// 126 127 bool mlirTypeIsAUniformQuantizedType(MlirType type) { 128 return isa<quant::UniformQuantizedType>(unwrap(type)); 129 } 130 131 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, 132 MlirType expressedType, double scale, 133 int64_t zeroPoint, int64_t storageTypeMin, 134 int64_t storageTypeMax) { 135 return wrap(quant::UniformQuantizedType::get( 136 flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint, 137 storageTypeMin, storageTypeMax)); 138 } 139 140 double mlirUniformQuantizedTypeGetScale(MlirType type) { 141 return cast<quant::UniformQuantizedType>(unwrap(type)).getScale(); 142 } 143 144 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { 145 return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint(); 146 } 147 148 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { 149 return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint(); 150 } 151 152 //===---------------------------------------------------------------------===// 153 // UniformQuantizedPerAxisType 154 //===---------------------------------------------------------------------===// 155 156 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { 157 return isa<quant::UniformQuantizedPerAxisType>(unwrap(type)); 158 } 159 160 MlirType mlirUniformQuantizedPerAxisTypeGet( 161 unsigned flags, MlirType storageType, MlirType expressedType, 162 intptr_t nDims, double *scales, int64_t *zeroPoints, 163 int32_t quantizedDimension, int64_t storageTypeMin, 164 int64_t storageTypeMax) { 165 return wrap(quant::UniformQuantizedPerAxisType::get( 166 flags, unwrap(storageType), unwrap(expressedType), 167 llvm::ArrayRef(scales, nDims), llvm::ArrayRef(zeroPoints, nDims), 168 quantizedDimension, storageTypeMin, storageTypeMax)); 169 } 170 171 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { 172 return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 173 .getScales() 174 .size(); 175 } 176 177 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { 178 return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 179 .getScales()[pos]; 180 } 181 182 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, 183 intptr_t pos) { 184 return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 185 .getZeroPoints()[pos]; 186 } 187 188 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { 189 return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)) 190 .getQuantizedDimension(); 191 } 192 193 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { 194 return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint(); 195 } 196 197 //===---------------------------------------------------------------------===// 198 // CalibratedQuantizedType 199 //===---------------------------------------------------------------------===// 200 201 bool mlirTypeIsACalibratedQuantizedType(MlirType type) { 202 return isa<quant::CalibratedQuantizedType>(unwrap(type)); 203 } 204 205 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, 206 double max) { 207 return wrap( 208 quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max)); 209 } 210 211 double mlirCalibratedQuantizedTypeGetMin(MlirType type) { 212 return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin(); 213 } 214 215 double mlirCalibratedQuantizedTypeGetMax(MlirType type) { 216 return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax(); 217 } 218