xref: /llvm-project/mlir/lib/Bindings/Python/DialectQuant.cpp (revision 285a229f205ae67dca48c8eac8206a115320c677)
1 //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Quant.h"
10 #include "mlir-c/IR.h"
11 #include "mlir/Bindings/Python/PybindAdaptors.h"
12 #include <cstdint>
13 #include <pybind11/cast.h>
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 #include <vector>
17 
18 namespace py = pybind11;
19 using namespace llvm;
20 using namespace mlir;
21 using namespace mlir::python::adaptors;
22 
23 static void populateDialectQuantSubmodule(const py::module &m) {
24   //===-------------------------------------------------------------------===//
25   // QuantizedType
26   //===-------------------------------------------------------------------===//
27 
28   auto quantizedType =
29       mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
30   quantizedType.def_staticmethod(
31       "default_minimum_for_integer",
32       [](bool isSigned, unsigned integralWidth) {
33         return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
34                                                             integralWidth);
35       },
36       "Default minimum value for the integer with the specified signedness and "
37       "bit width.",
38       py::arg("is_signed"), py::arg("integral_width"));
39   quantizedType.def_staticmethod(
40       "default_maximum_for_integer",
41       [](bool isSigned, unsigned integralWidth) {
42         return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
43                                                             integralWidth);
44       },
45       "Default maximum value for the integer with the specified signedness and "
46       "bit width.",
47       py::arg("is_signed"), py::arg("integral_width"));
48   quantizedType.def_property_readonly(
49       "expressed_type",
50       [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
51       "Type expressed by this quantized type.");
52   quantizedType.def_property_readonly(
53       "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
54       "Flags of this quantized type (named accessors should be preferred to "
55       "this)");
56   quantizedType.def_property_readonly(
57       "is_signed",
58       [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
59       "Signedness of this quantized type.");
60   quantizedType.def_property_readonly(
61       "storage_type",
62       [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
63       "Storage type backing this quantized type.");
64   quantizedType.def_property_readonly(
65       "storage_type_min",
66       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
67       "The minimum value held by the storage type of this quantized type.");
68   quantizedType.def_property_readonly(
69       "storage_type_max",
70       [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
71       "The maximum value held by the storage type of this quantized type.");
72   quantizedType.def_property_readonly(
73       "storage_type_integral_width",
74       [](MlirType type) {
75         return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
76       },
77       "The bitwidth of the storage type of this quantized type.");
78   quantizedType.def(
79       "is_compatible_expressed_type",
80       [](MlirType type, MlirType candidate) {
81         return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
82       },
83       "Checks whether the candidate type can be expressed by this quantized "
84       "type.",
85       py::arg("candidate"));
86   quantizedType.def_property_readonly(
87       "quantized_element_type",
88       [](MlirType type) {
89         return mlirQuantizedTypeGetQuantizedElementType(type);
90       },
91       "Element type of this quantized type expressed as quantized type.");
92   quantizedType.def(
93       "cast_from_storage_type",
94       [](MlirType type, MlirType candidate) {
95         MlirType castResult =
96             mlirQuantizedTypeCastFromStorageType(type, candidate);
97         if (!mlirTypeIsNull(castResult))
98           return castResult;
99         throw py::type_error("Invalid cast.");
100       },
101       "Casts from a type based on the storage type of this quantized type to a "
102       "corresponding type based on the quantized type. Raises TypeError if the "
103       "cast is not valid.",
104       py::arg("candidate"));
105   quantizedType.def_staticmethod(
106       "cast_to_storage_type",
107       [](MlirType type) {
108         MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
109         if (!mlirTypeIsNull(castResult))
110           return castResult;
111         throw py::type_error("Invalid cast.");
112       },
113       "Casts from a type based on a quantized type to a corresponding type "
114       "based on the storage type of this quantized type. Raises TypeError if "
115       "the cast is not valid.",
116       py::arg("type"));
117   quantizedType.def(
118       "cast_from_expressed_type",
119       [](MlirType type, MlirType candidate) {
120         MlirType castResult =
121             mlirQuantizedTypeCastFromExpressedType(type, candidate);
122         if (!mlirTypeIsNull(castResult))
123           return castResult;
124         throw py::type_error("Invalid cast.");
125       },
126       "Casts from a type based on the expressed type of this quantized type to "
127       "a corresponding type based on the quantized type. Raises TypeError if "
128       "the cast is not valid.",
129       py::arg("candidate"));
130   quantizedType.def_staticmethod(
131       "cast_to_expressed_type",
132       [](MlirType type) {
133         MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
134         if (!mlirTypeIsNull(castResult))
135           return castResult;
136         throw py::type_error("Invalid cast.");
137       },
138       "Casts from a type based on a quantized type to a corresponding type "
139       "based on the expressed type of this quantized type. Raises TypeError if "
140       "the cast is not valid.",
141       py::arg("type"));
142   quantizedType.def(
143       "cast_expressed_to_storage_type",
144       [](MlirType type, MlirType candidate) {
145         MlirType castResult =
146             mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
147         if (!mlirTypeIsNull(castResult))
148           return castResult;
149         throw py::type_error("Invalid cast.");
150       },
151       "Casts from a type based on the expressed type of this quantized type to "
152       "a corresponding type based on the storage type. Raises TypeError if the "
153       "cast is not valid.",
154       py::arg("candidate"));
155 
156   quantizedType.get_class().attr("FLAG_SIGNED") =
157       mlirQuantizedTypeGetSignedFlag();
158 
159   //===-------------------------------------------------------------------===//
160   // AnyQuantizedType
161   //===-------------------------------------------------------------------===//
162 
163   auto anyQuantizedType =
164       mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
165                          quantizedType.get_class());
166   anyQuantizedType.def_classmethod(
167       "get",
168       [](py::object cls, unsigned flags, MlirType storageType,
169          MlirType expressedType, int64_t storageTypeMin,
170          int64_t storageTypeMax) {
171         return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
172                                            storageTypeMin, storageTypeMax));
173       },
174       "Gets an instance of AnyQuantizedType in the same context as the "
175       "provided storage type.",
176       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
177       py::arg("expressed_type"), py::arg("storage_type_min"),
178       py::arg("storage_type_max"));
179 
180   //===-------------------------------------------------------------------===//
181   // UniformQuantizedType
182   //===-------------------------------------------------------------------===//
183 
184   auto uniformQuantizedType = mlir_type_subclass(
185       m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
186       quantizedType.get_class());
187   uniformQuantizedType.def_classmethod(
188       "get",
189       [](py::object cls, unsigned flags, MlirType storageType,
190          MlirType expressedType, double scale, int64_t zeroPoint,
191          int64_t storageTypeMin, int64_t storageTypeMax) {
192         return cls(mlirUniformQuantizedTypeGet(flags, storageType,
193                                                expressedType, scale, zeroPoint,
194                                                storageTypeMin, storageTypeMax));
195       },
196       "Gets an instance of UniformQuantizedType in the same context as the "
197       "provided storage type.",
198       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
199       py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
200       py::arg("storage_type_min"), py::arg("storage_type_max"));
201   uniformQuantizedType.def_property_readonly(
202       "scale",
203       [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
204       "The scale designates the difference between the real values "
205       "corresponding to consecutive quantized values differing by 1.");
206   uniformQuantizedType.def_property_readonly(
207       "zero_point",
208       [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
209       "The storage value corresponding to the real value 0 in the affine "
210       "equation.");
211   uniformQuantizedType.def_property_readonly(
212       "is_fixed_point",
213       [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
214       "Fixed point values are real numbers divided by a scale.");
215 
216   //===-------------------------------------------------------------------===//
217   // UniformQuantizedPerAxisType
218   //===-------------------------------------------------------------------===//
219   auto uniformQuantizedPerAxisType = mlir_type_subclass(
220       m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
221       quantizedType.get_class());
222   uniformQuantizedPerAxisType.def_classmethod(
223       "get",
224       [](py::object cls, unsigned flags, MlirType storageType,
225          MlirType expressedType, std::vector<double> scales,
226          std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
227          int64_t storageTypeMin, int64_t storageTypeMax) {
228         if (scales.size() != zeroPoints.size())
229           throw py::value_error(
230               "Mismatching number of scales and zero points.");
231         auto nDims = static_cast<intptr_t>(scales.size());
232         return cls(mlirUniformQuantizedPerAxisTypeGet(
233             flags, storageType, expressedType, nDims, scales.data(),
234             zeroPoints.data(), quantizedDimension, storageTypeMin,
235             storageTypeMax));
236       },
237       "Gets an instance of UniformQuantizedPerAxisType in the same context as "
238       "the provided storage type.",
239       py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
240       py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
241       py::arg("quantized_dimension"), py::arg("storage_type_min"),
242       py::arg("storage_type_max"));
243   uniformQuantizedPerAxisType.def_property_readonly(
244       "scales",
245       [](MlirType type) {
246         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
247         std::vector<double> scales;
248         scales.reserve(nDim);
249         for (intptr_t i = 0; i < nDim; ++i) {
250           double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
251           scales.push_back(scale);
252         }
253       },
254       "The scales designate the difference between the real values "
255       "corresponding to consecutive quantized values differing by 1. The ith "
256       "scale corresponds to the ith slice in the quantized_dimension.");
257   uniformQuantizedPerAxisType.def_property_readonly(
258       "zero_points",
259       [](MlirType type) {
260         intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
261         std::vector<int64_t> zeroPoints;
262         zeroPoints.reserve(nDim);
263         for (intptr_t i = 0; i < nDim; ++i) {
264           int64_t zeroPoint =
265               mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
266           zeroPoints.push_back(zeroPoint);
267         }
268       },
269       "the storage values corresponding to the real value 0 in the affine "
270       "equation. The ith zero point corresponds to the ith slice in the "
271       "quantized_dimension.");
272   uniformQuantizedPerAxisType.def_property_readonly(
273       "quantized_dimension",
274       [](MlirType type) {
275         return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
276       },
277       "Specifies the dimension of the shape that the scales and zero points "
278       "correspond to.");
279   uniformQuantizedPerAxisType.def_property_readonly(
280       "is_fixed_point",
281       [](MlirType type) {
282         return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
283       },
284       "Fixed point values are real numbers divided by a scale.");
285 
286   //===-------------------------------------------------------------------===//
287   // CalibratedQuantizedType
288   //===-------------------------------------------------------------------===//
289 
290   auto calibratedQuantizedType = mlir_type_subclass(
291       m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
292       quantizedType.get_class());
293   calibratedQuantizedType.def_classmethod(
294       "get",
295       [](py::object cls, MlirType expressedType, double min, double max) {
296         return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
297       },
298       "Gets an instance of CalibratedQuantizedType in the same context as the "
299       "provided expressed type.",
300       py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
301       py::arg("max"));
302   calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
303     return mlirCalibratedQuantizedTypeGetMin(type);
304   });
305   calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
306     return mlirCalibratedQuantizedTypeGetMax(type);
307   });
308 }
309 
310 PYBIND11_MODULE(_mlirDialectsQuant, m) {
311   m.doc() = "MLIR Quantization dialect";
312 
313   populateDialectQuantSubmodule(m);
314 }
315