1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Dialect/Quant.h" 10 #include "mlir-c/IR.h" 11 #include "mlir/Bindings/Python/PybindAdaptors.h" 12 #include <cstdint> 13 #include <pybind11/cast.h> 14 #include <pybind11/detail/common.h> 15 #include <pybind11/pybind11.h> 16 #include <vector> 17 18 namespace py = pybind11; 19 using namespace llvm; 20 using namespace mlir; 21 using namespace mlir::python::adaptors; 22 23 static void populateDialectQuantSubmodule(const py::module &m) { 24 //===-------------------------------------------------------------------===// 25 // QuantizedType 26 //===-------------------------------------------------------------------===// 27 28 auto quantizedType = 29 mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); 30 quantizedType.def_staticmethod( 31 "default_minimum_for_integer", 32 [](bool isSigned, unsigned integralWidth) { 33 return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, 34 integralWidth); 35 }, 36 "Default minimum value for the integer with the specified signedness and " 37 "bit width.", 38 py::arg("is_signed"), py::arg("integral_width")); 39 quantizedType.def_staticmethod( 40 "default_maximum_for_integer", 41 [](bool isSigned, unsigned integralWidth) { 42 return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, 43 integralWidth); 44 }, 45 "Default maximum value for the integer with the specified signedness and " 46 "bit width.", 47 py::arg("is_signed"), py::arg("integral_width")); 48 quantizedType.def_property_readonly( 49 "expressed_type", 50 [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, 51 "Type expressed by this quantized type."); 52 quantizedType.def_property_readonly( 53 "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, 54 "Flags of this quantized type (named accessors should be preferred to " 55 "this)"); 56 quantizedType.def_property_readonly( 57 "is_signed", 58 [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, 59 "Signedness of this quantized type."); 60 quantizedType.def_property_readonly( 61 "storage_type", 62 [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, 63 "Storage type backing this quantized type."); 64 quantizedType.def_property_readonly( 65 "storage_type_min", 66 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, 67 "The minimum value held by the storage type of this quantized type."); 68 quantizedType.def_property_readonly( 69 "storage_type_max", 70 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, 71 "The maximum value held by the storage type of this quantized type."); 72 quantizedType.def_property_readonly( 73 "storage_type_integral_width", 74 [](MlirType type) { 75 return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); 76 }, 77 "The bitwidth of the storage type of this quantized type."); 78 quantizedType.def( 79 "is_compatible_expressed_type", 80 [](MlirType type, MlirType candidate) { 81 return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); 82 }, 83 "Checks whether the candidate type can be expressed by this quantized " 84 "type.", 85 py::arg("candidate")); 86 quantizedType.def_property_readonly( 87 "quantized_element_type", 88 [](MlirType type) { 89 return mlirQuantizedTypeGetQuantizedElementType(type); 90 }, 91 "Element type of this quantized type expressed as quantized type."); 92 quantizedType.def( 93 "cast_from_storage_type", 94 [](MlirType type, MlirType candidate) { 95 MlirType castResult = 96 mlirQuantizedTypeCastFromStorageType(type, candidate); 97 if (!mlirTypeIsNull(castResult)) 98 return castResult; 99 throw py::type_error("Invalid cast."); 100 }, 101 "Casts from a type based on the storage type of this quantized type to a " 102 "corresponding type based on the quantized type. Raises TypeError if the " 103 "cast is not valid.", 104 py::arg("candidate")); 105 quantizedType.def_staticmethod( 106 "cast_to_storage_type", 107 [](MlirType type) { 108 MlirType castResult = mlirQuantizedTypeCastToStorageType(type); 109 if (!mlirTypeIsNull(castResult)) 110 return castResult; 111 throw py::type_error("Invalid cast."); 112 }, 113 "Casts from a type based on a quantized type to a corresponding type " 114 "based on the storage type of this quantized type. Raises TypeError if " 115 "the cast is not valid.", 116 py::arg("type")); 117 quantizedType.def( 118 "cast_from_expressed_type", 119 [](MlirType type, MlirType candidate) { 120 MlirType castResult = 121 mlirQuantizedTypeCastFromExpressedType(type, candidate); 122 if (!mlirTypeIsNull(castResult)) 123 return castResult; 124 throw py::type_error("Invalid cast."); 125 }, 126 "Casts from a type based on the expressed type of this quantized type to " 127 "a corresponding type based on the quantized type. Raises TypeError if " 128 "the cast is not valid.", 129 py::arg("candidate")); 130 quantizedType.def_staticmethod( 131 "cast_to_expressed_type", 132 [](MlirType type) { 133 MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); 134 if (!mlirTypeIsNull(castResult)) 135 return castResult; 136 throw py::type_error("Invalid cast."); 137 }, 138 "Casts from a type based on a quantized type to a corresponding type " 139 "based on the expressed type of this quantized type. Raises TypeError if " 140 "the cast is not valid.", 141 py::arg("type")); 142 quantizedType.def( 143 "cast_expressed_to_storage_type", 144 [](MlirType type, MlirType candidate) { 145 MlirType castResult = 146 mlirQuantizedTypeCastExpressedToStorageType(type, candidate); 147 if (!mlirTypeIsNull(castResult)) 148 return castResult; 149 throw py::type_error("Invalid cast."); 150 }, 151 "Casts from a type based on the expressed type of this quantized type to " 152 "a corresponding type based on the storage type. Raises TypeError if the " 153 "cast is not valid.", 154 py::arg("candidate")); 155 156 quantizedType.get_class().attr("FLAG_SIGNED") = 157 mlirQuantizedTypeGetSignedFlag(); 158 159 //===-------------------------------------------------------------------===// 160 // AnyQuantizedType 161 //===-------------------------------------------------------------------===// 162 163 auto anyQuantizedType = 164 mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, 165 quantizedType.get_class()); 166 anyQuantizedType.def_classmethod( 167 "get", 168 [](py::object cls, unsigned flags, MlirType storageType, 169 MlirType expressedType, int64_t storageTypeMin, 170 int64_t storageTypeMax) { 171 return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, 172 storageTypeMin, storageTypeMax)); 173 }, 174 "Gets an instance of AnyQuantizedType in the same context as the " 175 "provided storage type.", 176 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 177 py::arg("expressed_type"), py::arg("storage_type_min"), 178 py::arg("storage_type_max")); 179 180 //===-------------------------------------------------------------------===// 181 // UniformQuantizedType 182 //===-------------------------------------------------------------------===// 183 184 auto uniformQuantizedType = mlir_type_subclass( 185 m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, 186 quantizedType.get_class()); 187 uniformQuantizedType.def_classmethod( 188 "get", 189 [](py::object cls, unsigned flags, MlirType storageType, 190 MlirType expressedType, double scale, int64_t zeroPoint, 191 int64_t storageTypeMin, int64_t storageTypeMax) { 192 return cls(mlirUniformQuantizedTypeGet(flags, storageType, 193 expressedType, scale, zeroPoint, 194 storageTypeMin, storageTypeMax)); 195 }, 196 "Gets an instance of UniformQuantizedType in the same context as the " 197 "provided storage type.", 198 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 199 py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"), 200 py::arg("storage_type_min"), py::arg("storage_type_max")); 201 uniformQuantizedType.def_property_readonly( 202 "scale", 203 [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, 204 "The scale designates the difference between the real values " 205 "corresponding to consecutive quantized values differing by 1."); 206 uniformQuantizedType.def_property_readonly( 207 "zero_point", 208 [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, 209 "The storage value corresponding to the real value 0 in the affine " 210 "equation."); 211 uniformQuantizedType.def_property_readonly( 212 "is_fixed_point", 213 [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, 214 "Fixed point values are real numbers divided by a scale."); 215 216 //===-------------------------------------------------------------------===// 217 // UniformQuantizedPerAxisType 218 //===-------------------------------------------------------------------===// 219 auto uniformQuantizedPerAxisType = mlir_type_subclass( 220 m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, 221 quantizedType.get_class()); 222 uniformQuantizedPerAxisType.def_classmethod( 223 "get", 224 [](py::object cls, unsigned flags, MlirType storageType, 225 MlirType expressedType, std::vector<double> scales, 226 std::vector<int64_t> zeroPoints, int32_t quantizedDimension, 227 int64_t storageTypeMin, int64_t storageTypeMax) { 228 if (scales.size() != zeroPoints.size()) 229 throw py::value_error( 230 "Mismatching number of scales and zero points."); 231 auto nDims = static_cast<intptr_t>(scales.size()); 232 return cls(mlirUniformQuantizedPerAxisTypeGet( 233 flags, storageType, expressedType, nDims, scales.data(), 234 zeroPoints.data(), quantizedDimension, storageTypeMin, 235 storageTypeMax)); 236 }, 237 "Gets an instance of UniformQuantizedPerAxisType in the same context as " 238 "the provided storage type.", 239 py::arg("cls"), py::arg("flags"), py::arg("storage_type"), 240 py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"), 241 py::arg("quantized_dimension"), py::arg("storage_type_min"), 242 py::arg("storage_type_max")); 243 uniformQuantizedPerAxisType.def_property_readonly( 244 "scales", 245 [](MlirType type) { 246 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 247 std::vector<double> scales; 248 scales.reserve(nDim); 249 for (intptr_t i = 0; i < nDim; ++i) { 250 double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); 251 scales.push_back(scale); 252 } 253 }, 254 "The scales designate the difference between the real values " 255 "corresponding to consecutive quantized values differing by 1. The ith " 256 "scale corresponds to the ith slice in the quantized_dimension."); 257 uniformQuantizedPerAxisType.def_property_readonly( 258 "zero_points", 259 [](MlirType type) { 260 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 261 std::vector<int64_t> zeroPoints; 262 zeroPoints.reserve(nDim); 263 for (intptr_t i = 0; i < nDim; ++i) { 264 int64_t zeroPoint = 265 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); 266 zeroPoints.push_back(zeroPoint); 267 } 268 }, 269 "the storage values corresponding to the real value 0 in the affine " 270 "equation. The ith zero point corresponds to the ith slice in the " 271 "quantized_dimension."); 272 uniformQuantizedPerAxisType.def_property_readonly( 273 "quantized_dimension", 274 [](MlirType type) { 275 return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); 276 }, 277 "Specifies the dimension of the shape that the scales and zero points " 278 "correspond to."); 279 uniformQuantizedPerAxisType.def_property_readonly( 280 "is_fixed_point", 281 [](MlirType type) { 282 return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); 283 }, 284 "Fixed point values are real numbers divided by a scale."); 285 286 //===-------------------------------------------------------------------===// 287 // CalibratedQuantizedType 288 //===-------------------------------------------------------------------===// 289 290 auto calibratedQuantizedType = mlir_type_subclass( 291 m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, 292 quantizedType.get_class()); 293 calibratedQuantizedType.def_classmethod( 294 "get", 295 [](py::object cls, MlirType expressedType, double min, double max) { 296 return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); 297 }, 298 "Gets an instance of CalibratedQuantizedType in the same context as the " 299 "provided expressed type.", 300 py::arg("cls"), py::arg("expressed_type"), py::arg("min"), 301 py::arg("max")); 302 calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { 303 return mlirCalibratedQuantizedTypeGetMin(type); 304 }); 305 calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { 306 return mlirCalibratedQuantizedTypeGetMax(type); 307 }); 308 } 309 310 PYBIND11_MODULE(_mlirDialectsQuant, m) { 311 m.doc() = "MLIR Quantization dialect"; 312 313 populateDialectQuantSubmodule(m); 314 } 315