166d4090dSAlex Zinenko //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// 266d4090dSAlex Zinenko // 366d4090dSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 466d4090dSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 566d4090dSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 666d4090dSAlex Zinenko // 766d4090dSAlex Zinenko //===----------------------------------------------------------------------===// 866d4090dSAlex Zinenko 9285a229fSMehdi Amini #include <cstdint> 10285a229fSMehdi Amini #include <vector> 1166d4090dSAlex Zinenko 12*5cd42747SPeter Hawkins #include "mlir-c/Dialect/Quant.h" 13*5cd42747SPeter Hawkins #include "mlir-c/IR.h" 14*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h" 15*5cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 16*5cd42747SPeter Hawkins 17*5cd42747SPeter Hawkins namespace nb = nanobind; 1866d4090dSAlex Zinenko using namespace llvm; 1966d4090dSAlex Zinenko using namespace mlir; 20*5cd42747SPeter Hawkins using namespace mlir::python::nanobind_adaptors; 2166d4090dSAlex Zinenko 22*5cd42747SPeter Hawkins static void populateDialectQuantSubmodule(const nb::module_ &m) { 2366d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 2466d4090dSAlex Zinenko // QuantizedType 2566d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 2666d4090dSAlex Zinenko 2795ddbed9SAlex Zinenko auto quantizedType = 2895ddbed9SAlex Zinenko mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); 2966d4090dSAlex Zinenko quantizedType.def_staticmethod( 3066d4090dSAlex Zinenko "default_minimum_for_integer", 3166d4090dSAlex Zinenko [](bool isSigned, unsigned integralWidth) { 3266d4090dSAlex Zinenko return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, 3366d4090dSAlex Zinenko integralWidth); 3466d4090dSAlex Zinenko }, 3566d4090dSAlex Zinenko "Default minimum value for the integer with the specified signedness and " 3666d4090dSAlex Zinenko "bit width.", 37*5cd42747SPeter Hawkins nb::arg("is_signed"), nb::arg("integral_width")); 3866d4090dSAlex Zinenko quantizedType.def_staticmethod( 3966d4090dSAlex Zinenko "default_maximum_for_integer", 4066d4090dSAlex Zinenko [](bool isSigned, unsigned integralWidth) { 4166d4090dSAlex Zinenko return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, 4266d4090dSAlex Zinenko integralWidth); 4366d4090dSAlex Zinenko }, 4466d4090dSAlex Zinenko "Default maximum value for the integer with the specified signedness and " 4566d4090dSAlex Zinenko "bit width.", 46*5cd42747SPeter Hawkins nb::arg("is_signed"), nb::arg("integral_width")); 4766d4090dSAlex Zinenko quantizedType.def_property_readonly( 4866d4090dSAlex Zinenko "expressed_type", 4966d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, 5066d4090dSAlex Zinenko "Type expressed by this quantized type."); 5166d4090dSAlex Zinenko quantizedType.def_property_readonly( 5266d4090dSAlex Zinenko "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, 5366d4090dSAlex Zinenko "Flags of this quantized type (named accessors should be preferred to " 5466d4090dSAlex Zinenko "this)"); 5566d4090dSAlex Zinenko quantizedType.def_property_readonly( 5666d4090dSAlex Zinenko "is_signed", 5766d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, 5866d4090dSAlex Zinenko "Signedness of this quantized type."); 5966d4090dSAlex Zinenko quantizedType.def_property_readonly( 6066d4090dSAlex Zinenko "storage_type", 6166d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, 6266d4090dSAlex Zinenko "Storage type backing this quantized type."); 6366d4090dSAlex Zinenko quantizedType.def_property_readonly( 6466d4090dSAlex Zinenko "storage_type_min", 6566d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, 6666d4090dSAlex Zinenko "The minimum value held by the storage type of this quantized type."); 6766d4090dSAlex Zinenko quantizedType.def_property_readonly( 6866d4090dSAlex Zinenko "storage_type_max", 6966d4090dSAlex Zinenko [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, 7066d4090dSAlex Zinenko "The maximum value held by the storage type of this quantized type."); 7166d4090dSAlex Zinenko quantizedType.def_property_readonly( 7266d4090dSAlex Zinenko "storage_type_integral_width", 7366d4090dSAlex Zinenko [](MlirType type) { 7466d4090dSAlex Zinenko return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); 7566d4090dSAlex Zinenko }, 7666d4090dSAlex Zinenko "The bitwidth of the storage type of this quantized type."); 7766d4090dSAlex Zinenko quantizedType.def( 7866d4090dSAlex Zinenko "is_compatible_expressed_type", 7966d4090dSAlex Zinenko [](MlirType type, MlirType candidate) { 8066d4090dSAlex Zinenko return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); 8166d4090dSAlex Zinenko }, 8266d4090dSAlex Zinenko "Checks whether the candidate type can be expressed by this quantized " 8366d4090dSAlex Zinenko "type.", 84*5cd42747SPeter Hawkins nb::arg("candidate")); 8566d4090dSAlex Zinenko quantizedType.def_property_readonly( 8666d4090dSAlex Zinenko "quantized_element_type", 8766d4090dSAlex Zinenko [](MlirType type) { 8866d4090dSAlex Zinenko return mlirQuantizedTypeGetQuantizedElementType(type); 8966d4090dSAlex Zinenko }, 9066d4090dSAlex Zinenko "Element type of this quantized type expressed as quantized type."); 9166d4090dSAlex Zinenko quantizedType.def( 9266d4090dSAlex Zinenko "cast_from_storage_type", 9366d4090dSAlex Zinenko [](MlirType type, MlirType candidate) { 9466d4090dSAlex Zinenko MlirType castResult = 9566d4090dSAlex Zinenko mlirQuantizedTypeCastFromStorageType(type, candidate); 9666d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult)) 9766d4090dSAlex Zinenko return castResult; 98*5cd42747SPeter Hawkins throw nb::type_error("Invalid cast."); 9966d4090dSAlex Zinenko }, 10066d4090dSAlex Zinenko "Casts from a type based on the storage type of this quantized type to a " 10166d4090dSAlex Zinenko "corresponding type based on the quantized type. Raises TypeError if the " 10266d4090dSAlex Zinenko "cast is not valid.", 103*5cd42747SPeter Hawkins nb::arg("candidate")); 10466d4090dSAlex Zinenko quantizedType.def_staticmethod( 10566d4090dSAlex Zinenko "cast_to_storage_type", 10666d4090dSAlex Zinenko [](MlirType type) { 10766d4090dSAlex Zinenko MlirType castResult = mlirQuantizedTypeCastToStorageType(type); 10866d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult)) 10966d4090dSAlex Zinenko return castResult; 110*5cd42747SPeter Hawkins throw nb::type_error("Invalid cast."); 11166d4090dSAlex Zinenko }, 11266d4090dSAlex Zinenko "Casts from a type based on a quantized type to a corresponding type " 11366d4090dSAlex Zinenko "based on the storage type of this quantized type. Raises TypeError if " 11466d4090dSAlex Zinenko "the cast is not valid.", 115*5cd42747SPeter Hawkins nb::arg("type")); 11666d4090dSAlex Zinenko quantizedType.def( 11766d4090dSAlex Zinenko "cast_from_expressed_type", 11866d4090dSAlex Zinenko [](MlirType type, MlirType candidate) { 11966d4090dSAlex Zinenko MlirType castResult = 12066d4090dSAlex Zinenko mlirQuantizedTypeCastFromExpressedType(type, candidate); 12166d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult)) 12266d4090dSAlex Zinenko return castResult; 123*5cd42747SPeter Hawkins throw nb::type_error("Invalid cast."); 12466d4090dSAlex Zinenko }, 12566d4090dSAlex Zinenko "Casts from a type based on the expressed type of this quantized type to " 12666d4090dSAlex Zinenko "a corresponding type based on the quantized type. Raises TypeError if " 12766d4090dSAlex Zinenko "the cast is not valid.", 128*5cd42747SPeter Hawkins nb::arg("candidate")); 12966d4090dSAlex Zinenko quantizedType.def_staticmethod( 13066d4090dSAlex Zinenko "cast_to_expressed_type", 13166d4090dSAlex Zinenko [](MlirType type) { 13266d4090dSAlex Zinenko MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); 13366d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult)) 13466d4090dSAlex Zinenko return castResult; 135*5cd42747SPeter Hawkins throw nb::type_error("Invalid cast."); 13666d4090dSAlex Zinenko }, 13766d4090dSAlex Zinenko "Casts from a type based on a quantized type to a corresponding type " 13866d4090dSAlex Zinenko "based on the expressed type of this quantized type. Raises TypeError if " 13966d4090dSAlex Zinenko "the cast is not valid.", 140*5cd42747SPeter Hawkins nb::arg("type")); 14166d4090dSAlex Zinenko quantizedType.def( 14266d4090dSAlex Zinenko "cast_expressed_to_storage_type", 14366d4090dSAlex Zinenko [](MlirType type, MlirType candidate) { 14466d4090dSAlex Zinenko MlirType castResult = 14566d4090dSAlex Zinenko mlirQuantizedTypeCastExpressedToStorageType(type, candidate); 14666d4090dSAlex Zinenko if (!mlirTypeIsNull(castResult)) 14766d4090dSAlex Zinenko return castResult; 148*5cd42747SPeter Hawkins throw nb::type_error("Invalid cast."); 14966d4090dSAlex Zinenko }, 15066d4090dSAlex Zinenko "Casts from a type based on the expressed type of this quantized type to " 15166d4090dSAlex Zinenko "a corresponding type based on the storage type. Raises TypeError if the " 15266d4090dSAlex Zinenko "cast is not valid.", 153*5cd42747SPeter Hawkins nb::arg("candidate")); 15466d4090dSAlex Zinenko 15566d4090dSAlex Zinenko quantizedType.get_class().attr("FLAG_SIGNED") = 15666d4090dSAlex Zinenko mlirQuantizedTypeGetSignedFlag(); 15766d4090dSAlex Zinenko 15866d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 15966d4090dSAlex Zinenko // AnyQuantizedType 16066d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 16166d4090dSAlex Zinenko 16266d4090dSAlex Zinenko auto anyQuantizedType = 16366d4090dSAlex Zinenko mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, 16466d4090dSAlex Zinenko quantizedType.get_class()); 16566d4090dSAlex Zinenko anyQuantizedType.def_classmethod( 16666d4090dSAlex Zinenko "get", 167*5cd42747SPeter Hawkins [](nb::object cls, unsigned flags, MlirType storageType, 16866d4090dSAlex Zinenko MlirType expressedType, int64_t storageTypeMin, 16966d4090dSAlex Zinenko int64_t storageTypeMax) { 17066d4090dSAlex Zinenko return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, 17166d4090dSAlex Zinenko storageTypeMin, storageTypeMax)); 17266d4090dSAlex Zinenko }, 17366d4090dSAlex Zinenko "Gets an instance of AnyQuantizedType in the same context as the " 17466d4090dSAlex Zinenko "provided storage type.", 175*5cd42747SPeter Hawkins nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), 176*5cd42747SPeter Hawkins nb::arg("expressed_type"), nb::arg("storage_type_min"), 177*5cd42747SPeter Hawkins nb::arg("storage_type_max")); 17866d4090dSAlex Zinenko 17966d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 18066d4090dSAlex Zinenko // UniformQuantizedType 18166d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 18266d4090dSAlex Zinenko 18366d4090dSAlex Zinenko auto uniformQuantizedType = mlir_type_subclass( 18466d4090dSAlex Zinenko m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, 18566d4090dSAlex Zinenko quantizedType.get_class()); 18666d4090dSAlex Zinenko uniformQuantizedType.def_classmethod( 18766d4090dSAlex Zinenko "get", 188*5cd42747SPeter Hawkins [](nb::object cls, unsigned flags, MlirType storageType, 18966d4090dSAlex Zinenko MlirType expressedType, double scale, int64_t zeroPoint, 19066d4090dSAlex Zinenko int64_t storageTypeMin, int64_t storageTypeMax) { 19166d4090dSAlex Zinenko return cls(mlirUniformQuantizedTypeGet(flags, storageType, 19266d4090dSAlex Zinenko expressedType, scale, zeroPoint, 19366d4090dSAlex Zinenko storageTypeMin, storageTypeMax)); 19466d4090dSAlex Zinenko }, 19566d4090dSAlex Zinenko "Gets an instance of UniformQuantizedType in the same context as the " 19666d4090dSAlex Zinenko "provided storage type.", 197*5cd42747SPeter Hawkins nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), 198*5cd42747SPeter Hawkins nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"), 199*5cd42747SPeter Hawkins nb::arg("storage_type_min"), nb::arg("storage_type_max")); 20066d4090dSAlex Zinenko uniformQuantizedType.def_property_readonly( 20166d4090dSAlex Zinenko "scale", 20266d4090dSAlex Zinenko [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, 20366d4090dSAlex Zinenko "The scale designates the difference between the real values " 20466d4090dSAlex Zinenko "corresponding to consecutive quantized values differing by 1."); 20566d4090dSAlex Zinenko uniformQuantizedType.def_property_readonly( 20666d4090dSAlex Zinenko "zero_point", 20766d4090dSAlex Zinenko [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, 20866d4090dSAlex Zinenko "The storage value corresponding to the real value 0 in the affine " 20966d4090dSAlex Zinenko "equation."); 21066d4090dSAlex Zinenko uniformQuantizedType.def_property_readonly( 21166d4090dSAlex Zinenko "is_fixed_point", 21266d4090dSAlex Zinenko [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, 21366d4090dSAlex Zinenko "Fixed point values are real numbers divided by a scale."); 21466d4090dSAlex Zinenko 21566d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 21666d4090dSAlex Zinenko // UniformQuantizedPerAxisType 21766d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 21866d4090dSAlex Zinenko auto uniformQuantizedPerAxisType = mlir_type_subclass( 21966d4090dSAlex Zinenko m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, 22066d4090dSAlex Zinenko quantizedType.get_class()); 22166d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_classmethod( 22266d4090dSAlex Zinenko "get", 223*5cd42747SPeter Hawkins [](nb::object cls, unsigned flags, MlirType storageType, 22466d4090dSAlex Zinenko MlirType expressedType, std::vector<double> scales, 22566d4090dSAlex Zinenko std::vector<int64_t> zeroPoints, int32_t quantizedDimension, 22666d4090dSAlex Zinenko int64_t storageTypeMin, int64_t storageTypeMax) { 22766d4090dSAlex Zinenko if (scales.size() != zeroPoints.size()) 228*5cd42747SPeter Hawkins throw nb::value_error( 22966d4090dSAlex Zinenko "Mismatching number of scales and zero points."); 23066d4090dSAlex Zinenko auto nDims = static_cast<intptr_t>(scales.size()); 23166d4090dSAlex Zinenko return cls(mlirUniformQuantizedPerAxisTypeGet( 23266d4090dSAlex Zinenko flags, storageType, expressedType, nDims, scales.data(), 23366d4090dSAlex Zinenko zeroPoints.data(), quantizedDimension, storageTypeMin, 23466d4090dSAlex Zinenko storageTypeMax)); 23566d4090dSAlex Zinenko }, 23666d4090dSAlex Zinenko "Gets an instance of UniformQuantizedPerAxisType in the same context as " 23766d4090dSAlex Zinenko "the provided storage type.", 238*5cd42747SPeter Hawkins nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), 239*5cd42747SPeter Hawkins nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), 240*5cd42747SPeter Hawkins nb::arg("quantized_dimension"), nb::arg("storage_type_min"), 241*5cd42747SPeter Hawkins nb::arg("storage_type_max")); 24266d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly( 24366d4090dSAlex Zinenko "scales", 24466d4090dSAlex Zinenko [](MlirType type) { 24566d4090dSAlex Zinenko intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 24666d4090dSAlex Zinenko std::vector<double> scales; 24766d4090dSAlex Zinenko scales.reserve(nDim); 24866d4090dSAlex Zinenko for (intptr_t i = 0; i < nDim; ++i) { 24966d4090dSAlex Zinenko double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); 25066d4090dSAlex Zinenko scales.push_back(scale); 25166d4090dSAlex Zinenko } 25247ef5c4bSannuasd return scales; 25366d4090dSAlex Zinenko }, 25466d4090dSAlex Zinenko "The scales designate the difference between the real values " 25566d4090dSAlex Zinenko "corresponding to consecutive quantized values differing by 1. The ith " 25666d4090dSAlex Zinenko "scale corresponds to the ith slice in the quantized_dimension."); 25766d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly( 25866d4090dSAlex Zinenko "zero_points", 25966d4090dSAlex Zinenko [](MlirType type) { 26066d4090dSAlex Zinenko intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); 26166d4090dSAlex Zinenko std::vector<int64_t> zeroPoints; 26266d4090dSAlex Zinenko zeroPoints.reserve(nDim); 26366d4090dSAlex Zinenko for (intptr_t i = 0; i < nDim; ++i) { 26466d4090dSAlex Zinenko int64_t zeroPoint = 26566d4090dSAlex Zinenko mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); 26666d4090dSAlex Zinenko zeroPoints.push_back(zeroPoint); 26766d4090dSAlex Zinenko } 26847ef5c4bSannuasd return zeroPoints; 26966d4090dSAlex Zinenko }, 27066d4090dSAlex Zinenko "the storage values corresponding to the real value 0 in the affine " 27166d4090dSAlex Zinenko "equation. The ith zero point corresponds to the ith slice in the " 27266d4090dSAlex Zinenko "quantized_dimension."); 27366d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly( 27466d4090dSAlex Zinenko "quantized_dimension", 27566d4090dSAlex Zinenko [](MlirType type) { 27666d4090dSAlex Zinenko return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); 27766d4090dSAlex Zinenko }, 27866d4090dSAlex Zinenko "Specifies the dimension of the shape that the scales and zero points " 27966d4090dSAlex Zinenko "correspond to."); 28066d4090dSAlex Zinenko uniformQuantizedPerAxisType.def_property_readonly( 28166d4090dSAlex Zinenko "is_fixed_point", 28266d4090dSAlex Zinenko [](MlirType type) { 28366d4090dSAlex Zinenko return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); 28466d4090dSAlex Zinenko }, 28566d4090dSAlex Zinenko "Fixed point values are real numbers divided by a scale."); 28666d4090dSAlex Zinenko 28766d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 28866d4090dSAlex Zinenko // CalibratedQuantizedType 28966d4090dSAlex Zinenko //===-------------------------------------------------------------------===// 29066d4090dSAlex Zinenko 29166d4090dSAlex Zinenko auto calibratedQuantizedType = mlir_type_subclass( 29266d4090dSAlex Zinenko m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, 29366d4090dSAlex Zinenko quantizedType.get_class()); 29466d4090dSAlex Zinenko calibratedQuantizedType.def_classmethod( 29566d4090dSAlex Zinenko "get", 296*5cd42747SPeter Hawkins [](nb::object cls, MlirType expressedType, double min, double max) { 29766d4090dSAlex Zinenko return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); 29866d4090dSAlex Zinenko }, 29966d4090dSAlex Zinenko "Gets an instance of CalibratedQuantizedType in the same context as the " 30066d4090dSAlex Zinenko "provided expressed type.", 301*5cd42747SPeter Hawkins nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"), 302*5cd42747SPeter Hawkins nb::arg("max")); 30366d4090dSAlex Zinenko calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { 30466d4090dSAlex Zinenko return mlirCalibratedQuantizedTypeGetMin(type); 30566d4090dSAlex Zinenko }); 30666d4090dSAlex Zinenko calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { 30766d4090dSAlex Zinenko return mlirCalibratedQuantizedTypeGetMax(type); 30866d4090dSAlex Zinenko }); 30966d4090dSAlex Zinenko } 31095ddbed9SAlex Zinenko 311*5cd42747SPeter Hawkins NB_MODULE(_mlirDialectsQuant, m) { 31295ddbed9SAlex Zinenko m.doc() = "MLIR Quantization dialect"; 31395ddbed9SAlex Zinenko 31495ddbed9SAlex Zinenko populateDialectQuantSubmodule(m); 31595ddbed9SAlex Zinenko } 316