xref: /llvm-project/mlir/lib/Bindings/Python/DialectQuant.cpp (revision 5cd427477218d8bdb659c6c53a7758f741c3990a)
1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <vector>
11 
12 #include "mlir-c/Dialect/Quant.h"
13 #include "mlir-c/IR.h"
14 #include "mlir/Bindings/Python/NanobindAdaptors.h"
15 #include "mlir/Bindings/Python/Nanobind.h"
16 
17 namespace nb = nanobind;
18 using namespace llvm;
19 using namespace mlir;
20 using namespace mlir::python::nanobind_adaptors;
21 
22 static void populateDialectQuantSubmodule(const nb::module_ &m) {
23   //===-------------------------------------------------------------------===//
24   // QuantizedType
25   //===-------------------------------------------------------------------===//
26 
27   auto quantizedType =
28       mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
29   quantizedType.def_staticmethod(
30       "default_minimum_for_integer",
31       [](bool isSigned, unsigned integralWidth) {
32         return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
33                                                             integralWidth);
34       },
35       "Default minimum value for the integer with the specified signedness and "
36       "bit width.",
37       nb::arg("is_signed"), nb::arg("integral_width"));
38   quantizedType.def_staticmethod(
39       "default_maximum_for_integer",
40       [](bool isSigned, unsigned integralWidth) {
41         return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
42                                                             integralWidth);
43       },
44       "Default maximum value for the integer with the specified signedness and "
45       "bit width.",
46       nb::arg("is_signed"), nb::arg("integral_width"));
47   quantizedType.def_property_readonly(
48       "expressed_type",
49       [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
50       "Type expressed by this quantized type.");
51   quantizedType.def_property_readonly(
52       "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
53       "Flags of this quantized type (named accessors should be preferred to "
54       "this)");
55   quantizedType.def_property_readonly(
56       "is_signed",
57       [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
58       "Signedness of this quantized type.");
59   quantizedType.def_property_readonly(
60       "storage_type",
61       [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
62       "Storage type backing this quantized type.");
63   quantizedType.def_property_readonly(
64       "storage_type_min",
65       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
66       "The minimum value held by the storage type of this quantized type.");
67   quantizedType.def_property_readonly(
68       "storage_type_max",
69       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
70       "The maximum value held by the storage type of this quantized type.");
71   quantizedType.def_property_readonly(
72       "storage_type_integral_width",
73       [](MlirType type) {
74         return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
75       },
76       "The bitwidth of the storage type of this quantized type.");
77   quantizedType.def(
78       "is_compatible_expressed_type",
79       [](MlirType type, MlirType candidate) {
80         return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
81       },
82       "Checks whether the candidate type can be expressed by this quantized "
83       "type.",
84       nb::arg("candidate"));
85   quantizedType.def_property_readonly(
86       "quantized_element_type",
87       [](MlirType type) {
88         return mlirQuantizedTypeGetQuantizedElementType(type);
89       },
90       "Element type of this quantized type expressed as quantized type.");
91   quantizedType.def(
92       "cast_from_storage_type",
93       [](MlirType type, MlirType candidate) {
94         MlirType castResult =
95             mlirQuantizedTypeCastFromStorageType(type, candidate);
96         if (!mlirTypeIsNull(castResult))
97           return castResult;
98         throw nb::type_error("Invalid cast.");
99       },
100       "Casts from a type based on the storage type of this quantized type to a "
101       "corresponding type based on the quantized type. Raises TypeError if the "
102       "cast is not valid.",
103       nb::arg("candidate"));
104   quantizedType.def_staticmethod(
105       "cast_to_storage_type",
106       [](MlirType type) {
107         MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
108         if (!mlirTypeIsNull(castResult))
109           return castResult;
110         throw nb::type_error("Invalid cast.");
111       },
112       "Casts from a type based on a quantized type to a corresponding type "
113       "based on the storage type of this quantized type. Raises TypeError if "
114       "the cast is not valid.",
115       nb::arg("type"));
116   quantizedType.def(
117       "cast_from_expressed_type",
118       [](MlirType type, MlirType candidate) {
119         MlirType castResult =
120             mlirQuantizedTypeCastFromExpressedType(type, candidate);
121         if (!mlirTypeIsNull(castResult))
122           return castResult;
123         throw nb::type_error("Invalid cast.");
124       },
125       "Casts from a type based on the expressed type of this quantized type to "
126       "a corresponding type based on the quantized type. Raises TypeError if "
127       "the cast is not valid.",
128       nb::arg("candidate"));
129   quantizedType.def_staticmethod(
130       "cast_to_expressed_type",
131       [](MlirType type) {
132         MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
133         if (!mlirTypeIsNull(castResult))
134           return castResult;
135         throw nb::type_error("Invalid cast.");
136       },
137       "Casts from a type based on a quantized type to a corresponding type "
138       "based on the expressed type of this quantized type. Raises TypeError if "
139       "the cast is not valid.",
140       nb::arg("type"));
141   quantizedType.def(
142       "cast_expressed_to_storage_type",
143       [](MlirType type, MlirType candidate) {
144         MlirType castResult =
145             mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
146         if (!mlirTypeIsNull(castResult))
147           return castResult;
148         throw nb::type_error("Invalid cast.");
149       },
150       "Casts from a type based on the expressed type of this quantized type to "
151       "a corresponding type based on the storage type. Raises TypeError if the "
152       "cast is not valid.",
153       nb::arg("candidate"));
154 
155   quantizedType.get_class().attr("FLAG_SIGNED") =
156       mlirQuantizedTypeGetSignedFlag();
157 
158   //===-------------------------------------------------------------------===//
159   // AnyQuantizedType
160   //===-------------------------------------------------------------------===//
161 
162   auto anyQuantizedType =
163       mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
164                          quantizedType.get_class());
165   anyQuantizedType.def_classmethod(
166       "get",
167       [](nb::object cls, unsigned flags, MlirType storageType,
168          MlirType expressedType, int64_t storageTypeMin,
169          int64_t storageTypeMax) {
170         return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
171                                            storageTypeMin, storageTypeMax));
172       },
173       "Gets an instance of AnyQuantizedType in the same context as the "
174       "provided storage type.",
175       nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
176       nb::arg("expressed_type"), nb::arg("storage_type_min"),
177       nb::arg("storage_type_max"));
178 
179   //===-------------------------------------------------------------------===//
180   // UniformQuantizedType
181   //===-------------------------------------------------------------------===//
182 
183   auto uniformQuantizedType = mlir_type_subclass(
184       m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
185       quantizedType.get_class());
186   uniformQuantizedType.def_classmethod(
187       "get",
188       [](nb::object cls, unsigned flags, MlirType storageType,
189          MlirType expressedType, double scale, int64_t zeroPoint,
190          int64_t storageTypeMin, int64_t storageTypeMax) {
191         return cls(mlirUniformQuantizedTypeGet(flags, storageType,
192                                                expressedType, scale, zeroPoint,
193                                                storageTypeMin, storageTypeMax));
194       },
195       "Gets an instance of UniformQuantizedType in the same context as the "
196       "provided storage type.",
197       nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
198       nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
199       nb::arg("storage_type_min"), nb::arg("storage_type_max"));
200   uniformQuantizedType.def_property_readonly(
201       "scale",
202       [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
203       "The scale designates the difference between the real values "
204       "corresponding to consecutive quantized values differing by 1.");
205   uniformQuantizedType.def_property_readonly(
206       "zero_point",
207       [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
208       "The storage value corresponding to the real value 0 in the affine "
209       "equation.");
210   uniformQuantizedType.def_property_readonly(
211       "is_fixed_point",
212       [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
213       "Fixed point values are real numbers divided by a scale.");
214 
215   //===-------------------------------------------------------------------===//
216   // UniformQuantizedPerAxisType
217   //===-------------------------------------------------------------------===//
218   auto uniformQuantizedPerAxisType = mlir_type_subclass(
219       m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
220       quantizedType.get_class());
221   uniformQuantizedPerAxisType.def_classmethod(
222       "get",
223       [](nb::object cls, unsigned flags, MlirType storageType,
224          MlirType expressedType, std::vector<double> scales,
225          std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
226          int64_t storageTypeMin, int64_t storageTypeMax) {
227         if (scales.size() != zeroPoints.size())
228           throw nb::value_error(
229               "Mismatching number of scales and zero points.");
230         auto nDims = static_cast<intptr_t>(scales.size());
231         return cls(mlirUniformQuantizedPerAxisTypeGet(
232             flags, storageType, expressedType, nDims, scales.data(),
233             zeroPoints.data(), quantizedDimension, storageTypeMin,
234             storageTypeMax));
235       },
236       "Gets an instance of UniformQuantizedPerAxisType in the same context as "
237       "the provided storage type.",
238       nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
239       nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
240       nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
241       nb::arg("storage_type_max"));
242   uniformQuantizedPerAxisType.def_property_readonly(
243       "scales",
244       [](MlirType type) {
245         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
246         std::vector<double> scales;
247         scales.reserve(nDim);
248         for (intptr_t i = 0; i < nDim; ++i) {
249           double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
250           scales.push_back(scale);
251         }
252         return scales;
253       },
254       "The scales designate the difference between the real values "
255       "corresponding to consecutive quantized values differing by 1. The ith "
256       "scale corresponds to the ith slice in the quantized_dimension.");
257   uniformQuantizedPerAxisType.def_property_readonly(
258       "zero_points",
259       [](MlirType type) {
260         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
261         std::vector<int64_t> zeroPoints;
262         zeroPoints.reserve(nDim);
263         for (intptr_t i = 0; i < nDim; ++i) {
264           int64_t zeroPoint =
265               mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
266           zeroPoints.push_back(zeroPoint);
267         }
268         return zeroPoints;
269       },
270       "the storage values corresponding to the real value 0 in the affine "
271       "equation. The ith zero point corresponds to the ith slice in the "
272       "quantized_dimension.");
273   uniformQuantizedPerAxisType.def_property_readonly(
274       "quantized_dimension",
275       [](MlirType type) {
276         return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
277       },
278       "Specifies the dimension of the shape that the scales and zero points "
279       "correspond to.");
280   uniformQuantizedPerAxisType.def_property_readonly(
281       "is_fixed_point",
282       [](MlirType type) {
283         return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
284       },
285       "Fixed point values are real numbers divided by a scale.");
286 
287   //===-------------------------------------------------------------------===//
288   // CalibratedQuantizedType
289   //===-------------------------------------------------------------------===//
290 
291   auto calibratedQuantizedType = mlir_type_subclass(
292       m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
293       quantizedType.get_class());
294   calibratedQuantizedType.def_classmethod(
295       "get",
296       [](nb::object cls, MlirType expressedType, double min, double max) {
297         return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
298       },
299       "Gets an instance of CalibratedQuantizedType in the same context as the "
300       "provided expressed type.",
301       nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
302       nb::arg("max"));
303   calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
304     return mlirCalibratedQuantizedTypeGetMin(type);
305   });
306   calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
307     return mlirCalibratedQuantizedTypeGetMax(type);
308   });
309 }
310 
311 NB_MODULE(_mlirDialectsQuant, m) {
312   m.doc() = "MLIR Quantization dialect";
313 
314   populateDialectQuantSubmodule(m);
315 }
316