1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <cstdint> 10 #include <vector> 11 12 #include "mlir-c/Dialect/Quant.h" 13 #include "mlir-c/IR.h" 14 #include "mlir/Bindings/Python/NanobindAdaptors.h" 15 #include "mlir/Bindings/Python/Nanobind.h" 16 17 namespace nb = nanobind; 18 using namespace llvm; 19 using namespace mlir; 20 using namespace mlir::python::nanobind_adaptors; 21 22 static void populateDialectQuantSubmodule(const nb::module_ &m) { 23 //===-------------------------------------------------------------------===// 24 // QuantizedType 25 //===-------------------------------------------------------------------===// 26 27 auto quantizedType = 28 mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); 29 quantizedType.def_staticmethod( 30 "default_minimum_for_integer", 31 [](bool isSigned, unsigned integralWidth) { 32 return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, 33 integralWidth); 34 }, 35 "Default minimum value for the integer with the specified signedness and " 36 "bit width.", 37 nb::arg("is_signed"), nb::arg("integral_width")); 38 quantizedType.def_staticmethod( 39 "default_maximum_for_integer", 40 [](bool isSigned, unsigned integralWidth) { 41 return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, 42 integralWidth); 43 }, 44 "Default maximum value for the integer with the specified signedness and " 45 "bit width.", 46 nb::arg("is_signed"), nb::arg("integral_width")); 47 quantizedType.def_property_readonly( 48 "expressed_type", 49 [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, 50 "Type expressed by this quantized type."); 51 quantizedType.def_property_readonly( 52 "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, 53 "Flags of this quantized type (named accessors should be preferred to " 54 "this)"); 55 quantizedType.def_property_readonly( 56 "is_signed", 57 [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, 58 "Signedness of this quantized type."); 59 quantizedType.def_property_readonly( 60 "storage_type", 61 [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, 62 "Storage type backing this quantized type."); 63 quantizedType.def_property_readonly( 64 "storage_type_min", 65 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, 66 "The minimum value held by the storage type of this quantized type."); 67 quantizedType.def_property_readonly( 68 "storage_type_max", 69 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, 70 "The maximum value held by the storage type of this quantized type."); 71 quantizedType.def_property_readonly( 72 "storage_type_integral_width", 73 [](MlirType type) { 74 return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); 75 }, 76 "The bitwidth of the storage type of this quantized type."); 77 quantizedType.def( 78 "is_compatible_expressed_type", 79 [](MlirType type, MlirType candidate) { 80 return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); 81 }, 82 "Checks whether the candidate type can be expressed by this quantized " 83 "type.", 84 nb::arg("candidate")); 85 quantizedType.def_property_readonly( 86 "quantized_element_type", 87 [](MlirType type) { 88 return mlirQuantizedTypeGetQuantizedElementType(type); 89 }, 90 "Element type of this quantized type expressed as quantized type."); 91 quantizedType.def( 92 "cast_from_storage_type", 93 [](MlirType type, MlirType candidate) { 94 MlirType castResult = 95 mlirQuantizedTypeCastFromStorageType(type, candidate); 96 if (!mlirTypeIsNull(castResult)) 97 return castResult; 98 throw nb::type_error("Invalid cast."); 99 }, 100 "Casts from a type based on the storage type of this quantized type to a " 101 "corresponding type based on the quantized type. Raises TypeError if the " 102 "cast is not valid.", 103 nb::arg("candidate")); 104 quantizedType.def_staticmethod( 105 "cast_to_storage_type", 106 [](MlirType type) { 107 MlirType castResult = mlirQuantizedTypeCastToStorageType(type); 108 if (!mlirTypeIsNull(castResult)) 109 return castResult; 110 throw nb::type_error("Invalid cast."); 111 }, 112 "Casts from a type based on a quantized type to a corresponding type " 113 "based on the storage type of this quantized type. Raises TypeError if " 114 "the cast is not valid.", 115 nb::arg("type")); 116 quantizedType.def( 117 "cast_from_expressed_type", 118 [](MlirType type, MlirType candidate) { 119 MlirType castResult = 120 mlirQuantizedTypeCastFromExpressedType(type, candidate); 121 if (!mlirTypeIsNull(castResult)) 122 return castResult; 123 throw nb::type_error("Invalid cast."); 124 }, 125 "Casts from a type based on the expressed type of this quantized type to " 126 "a corresponding type based on the quantized type. Raises TypeError if " 127 "the cast is not valid.", 128 nb::arg("candidate")); 129 quantizedType.def_staticmethod( 130 "cast_to_expressed_type", 131 [](MlirType type) { 132 MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); 133 if (!mlirTypeIsNull(castResult)) 134 return castResult; 135 throw nb::type_error("Invalid cast."); 136 }, 137 "Casts from a type based on a quantized type to a corresponding type " 138 "based on the expressed type of this quantized type. Raises TypeError if " 139 "the cast is not valid.", 140 nb::arg("type")); 141 quantizedType.def( 142 "cast_expressed_to_storage_type", 143 [](MlirType type, MlirType candidate) { 144 MlirType castResult = 145 mlirQuantizedTypeCastExpressedToStorageType(type, candidate); 146 if (!mlirTypeIsNull(castResult)) 147 return castResult; 148 throw nb::type_error("Invalid cast."); 149 }, 150 "Casts from a type based on the expressed type of this quantized type to " 151 "a corresponding type based on the storage type. Raises TypeError if the " 152 "cast is not valid.", 153 nb::arg("candidate")); 154 155 quantizedType.get_class().attr("FLAG_SIGNED") = 156 mlirQuantizedTypeGetSignedFlag(); 157 158 //===-------------------------------------------------------------------===// 159 // AnyQuantizedType 160 //===-------------------------------------------------------------------===// 161 162 auto anyQuantizedType = 163 mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, 164 quantizedType.get_class()); 165 anyQuantizedType.def_classmethod( 166 "get", 167 [](nb::object cls, unsigned flags, MlirType storageType, 168 MlirType expressedType, int64_t storageTypeMin, 169 int64_t storageTypeMax) { 170 return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, 171 storageTypeMin, storageTypeMax)); 172 }, 173 "Gets an instance of AnyQuantizedType in the same context as the " 174 "provided storage type.", 175 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), 176 nb::arg("expressed_type"), nb::arg("storage_type_min"), 177 nb::arg("storage_type_max")); 178 179 //===-------------------------------------------------------------------===// 180 // UniformQuantizedType 181 //===-------------------------------------------------------------------===// 182 183 auto uniformQuantizedType = mlir_type_subclass( 184 m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, 185 quantizedType.get_class()); 186 uniformQuantizedType.def_classmethod( 187 "get", 188 [](nb::object cls, unsigned flags, MlirType storageType, 189 MlirType expressedType, double scale, int64_t zeroPoint, 190 int64_t storageTypeMin, int64_t storageTypeMax) { 191 return cls(mlirUniformQuantizedTypeGet(flags, storageType, 192 expressedType, scale, zeroPoint, 193 storageTypeMin, storageTypeMax)); 194 }, 195 "Gets an instance of UniformQuantizedType in the same context as the " 196 "provided storage type.", 197 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), 198 nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"), 199 nb::arg("storage_type_min"), nb::arg("storage_type_max")); 200 uniformQuantizedType.def_property_readonly( 201 "scale", 202 [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, 203 "The scale designates the difference between the real values " 204 "corresponding to consecutive quantized values differing by 1."); 205 uniformQuantizedType.def_property_readonly( 206 "zero_point", 207 [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, 208 "The storage value corresponding to the real value 0 in the affine " 209 "equation."); 210 uniformQuantizedType.def_property_readonly( 211 "is_fixed_point", 212 [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, 213 "Fixed point values are real numbers divided by a scale."); 214 215 //===-------------------------------------------------------------------===// 216 // UniformQuantizedPerAxisType 217 //===-------------------------------------------------------------------===// 218 auto uniformQuantizedPerAxisType = mlir_type_subclass( 219 m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, 220 quantizedType.get_class()); 221 uniformQuantizedPerAxisType.def_classmethod( 222 "get", 223 [](nb::object cls, unsigned flags, MlirType storageType, 224 MlirType expressedType, std::vector<double> scales, 225 std::vector<int64_t> zeroPoints, int32_t quantizedDimension, 226 int64_t storageTypeMin, int64_t storageTypeMax) { 227 if (scales.size() != zeroPoints.size()) 228 throw nb::value_error( 229 "Mismatching number of scales and zero points."); 230 auto nDims = static_cast<intptr_t>(scales.size()); 231 return cls(mlirUniformQuantizedPerAxisTypeGet( 232 flags, storageType, expressedType, nDims, scales.data(), 233 zeroPoints.data(), quantizedDimension, storageTypeMin, 234 storageTypeMax)); 235 }, 236 "Gets an instance of UniformQuantizedPerAxisType in the same context as " 237 "the provided storage type.", 238 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), 239 nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), 240 nb::arg("quantized_dimension"), nb::arg("storage_type_min"), 241 nb::arg("storage_type_max")); 242 uniformQuantizedPerAxisType.def_property_readonly( 243 "scales", 244 [](MlirType type) { 245 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 246 std::vector<double> scales; 247 scales.reserve(nDim); 248 for (intptr_t i = 0; i < nDim; ++i) { 249 double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); 250 scales.push_back(scale); 251 } 252 return scales; 253 }, 254 "The scales designate the difference between the real values " 255 "corresponding to consecutive quantized values differing by 1. The ith " 256 "scale corresponds to the ith slice in the quantized_dimension."); 257 uniformQuantizedPerAxisType.def_property_readonly( 258 "zero_points", 259 [](MlirType type) { 260 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 261 std::vector<int64_t> zeroPoints; 262 zeroPoints.reserve(nDim); 263 for (intptr_t i = 0; i < nDim; ++i) { 264 int64_t zeroPoint = 265 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); 266 zeroPoints.push_back(zeroPoint); 267 } 268 return zeroPoints; 269 }, 270 "the storage values corresponding to the real value 0 in the affine " 271 "equation. The ith zero point corresponds to the ith slice in the " 272 "quantized_dimension."); 273 uniformQuantizedPerAxisType.def_property_readonly( 274 "quantized_dimension", 275 [](MlirType type) { 276 return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); 277 }, 278 "Specifies the dimension of the shape that the scales and zero points " 279 "correspond to."); 280 uniformQuantizedPerAxisType.def_property_readonly( 281 "is_fixed_point", 282 [](MlirType type) { 283 return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); 284 }, 285 "Fixed point values are real numbers divided by a scale."); 286 287 //===-------------------------------------------------------------------===// 288 // CalibratedQuantizedType 289 //===-------------------------------------------------------------------===// 290 291 auto calibratedQuantizedType = mlir_type_subclass( 292 m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, 293 quantizedType.get_class()); 294 calibratedQuantizedType.def_classmethod( 295 "get", 296 [](nb::object cls, MlirType expressedType, double min, double max) { 297 return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); 298 }, 299 "Gets an instance of CalibratedQuantizedType in the same context as the " 300 "provided expressed type.", 301 nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"), 302 nb::arg("max")); 303 calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { 304 return mlirCalibratedQuantizedTypeGetMin(type); 305 }); 306 calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { 307 return mlirCalibratedQuantizedTypeGetMax(type); 308 }); 309 } 310 311 NB_MODULE(_mlirDialectsQuant, m) { 312 m.doc() = "MLIR Quantization dialect"; 313 314 populateDialectQuantSubmodule(m); 315 } 316