xref: /llvm-project/mlir/lib/CAPI/Dialect/Quant.cpp (revision 852b6486246141e44cc9f126f542a2ae0d73b3d6)
1 //===- Quant.cpp - C Interface for Quant dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Quant.h"
10 #include "mlir/CAPI/Registration.h"
11 #include "mlir/Dialect/Quant/IR/Quant.h"
12 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
13 
14 using namespace mlir;
15 
// Instantiates the standard C-API dialect-registration boilerplate for
// quant::QuantDialect, exposed under the name/namespace `quant`.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
17 
18 //===---------------------------------------------------------------------===//
19 // QuantizedType
20 //===---------------------------------------------------------------------===//
21 
22 bool mlirTypeIsAQuantizedType(MlirType type) {
23   return isa<quant::QuantizedType>(unwrap(type));
24 }
25 
26 unsigned mlirQuantizedTypeGetSignedFlag() {
27   return quant::QuantizationFlags::Signed;
28 }
29 
30 int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
31                                                      unsigned integralWidth) {
32   return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
33                                                            integralWidth);
34 }
35 
36 int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
37                                                      unsigned integralWidth) {
38   return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
39                                                            integralWidth);
40 }
41 
42 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
43   return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType());
44 }
45 
46 unsigned mlirQuantizedTypeGetFlags(MlirType type) {
47   return cast<quant::QuantizedType>(unwrap(type)).getFlags();
48 }
49 
50 bool mlirQuantizedTypeIsSigned(MlirType type) {
51   return cast<quant::QuantizedType>(unwrap(type)).isSigned();
52 }
53 
54 MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
55   return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType());
56 }
57 
58 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
59   return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin();
60 }
61 
62 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
63   return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax();
64 }
65 
66 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
67   return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth();
68 }
69 
70 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
71                                                 MlirType candidate) {
72   return cast<quant::QuantizedType>(unwrap(type))
73       .isCompatibleExpressedType(unwrap(candidate));
74 }
75 
76 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
77   return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type)));
78 }
79 
80 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
81                                               MlirType candidate) {
82   return wrap(cast<quant::QuantizedType>(unwrap(type))
83                   .castFromStorageType(unwrap(candidate)));
84 }
85 
86 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
87   return wrap(quant::QuantizedType::castToStorageType(
88       cast<quant::QuantizedType>(unwrap(type))));
89 }
90 
91 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
92                                                 MlirType candidate) {
93   return wrap(cast<quant::QuantizedType>(unwrap(type))
94                   .castFromExpressedType(unwrap(candidate)));
95 }
96 
97 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
98   return wrap(quant::QuantizedType::castToExpressedType(unwrap(type)));
99 }
100 
101 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
102                                                      MlirType candidate) {
103   return wrap(cast<quant::QuantizedType>(unwrap(type))
104                   .castExpressedToStorageType(unwrap(candidate)));
105 }
106 
107 //===---------------------------------------------------------------------===//
108 // AnyQuantizedType
109 //===---------------------------------------------------------------------===//
110 
111 bool mlirTypeIsAAnyQuantizedType(MlirType type) {
112   return isa<quant::AnyQuantizedType>(unwrap(type));
113 }
114 
115 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
116                                  MlirType expressedType, int64_t storageTypeMin,
117                                  int64_t storageTypeMax) {
118   return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType),
119                                            unwrap(expressedType),
120                                            storageTypeMin, storageTypeMax));
121 }
122 
123 //===---------------------------------------------------------------------===//
124 // UniformQuantizedType
125 //===---------------------------------------------------------------------===//
126 
127 bool mlirTypeIsAUniformQuantizedType(MlirType type) {
128   return isa<quant::UniformQuantizedType>(unwrap(type));
129 }
130 
131 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
132                                      MlirType expressedType, double scale,
133                                      int64_t zeroPoint, int64_t storageTypeMin,
134                                      int64_t storageTypeMax) {
135   return wrap(quant::UniformQuantizedType::get(
136       flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint,
137       storageTypeMin, storageTypeMax));
138 }
139 
140 double mlirUniformQuantizedTypeGetScale(MlirType type) {
141   return cast<quant::UniformQuantizedType>(unwrap(type)).getScale();
142 }
143 
144 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
145   return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint();
146 }
147 
148 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
149   return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint();
150 }
151 
152 //===---------------------------------------------------------------------===//
153 // UniformQuantizedPerAxisType
154 //===---------------------------------------------------------------------===//
155 
156 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
157   return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
158 }
159 
160 MlirType mlirUniformQuantizedPerAxisTypeGet(
161     unsigned flags, MlirType storageType, MlirType expressedType,
162     intptr_t nDims, double *scales, int64_t *zeroPoints,
163     int32_t quantizedDimension, int64_t storageTypeMin,
164     int64_t storageTypeMax) {
165   return wrap(quant::UniformQuantizedPerAxisType::get(
166       flags, unwrap(storageType), unwrap(expressedType),
167       llvm::ArrayRef(scales, nDims), llvm::ArrayRef(zeroPoints, nDims),
168       quantizedDimension, storageTypeMin, storageTypeMax));
169 }
170 
171 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
172   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
173       .getScales()
174       .size();
175 }
176 
177 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
178   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
179       .getScales()[pos];
180 }
181 
182 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
183                                                     intptr_t pos) {
184   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
185       .getZeroPoints()[pos];
186 }
187 
188 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
189   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
190       .getQuantizedDimension();
191 }
192 
193 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
194   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
195 }
196 
197 //===---------------------------------------------------------------------===//
198 // CalibratedQuantizedType
199 //===---------------------------------------------------------------------===//
200 
201 bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
202   return isa<quant::CalibratedQuantizedType>(unwrap(type));
203 }
204 
205 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
206                                         double max) {
207   return wrap(
208       quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
209 }
210 
211 double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
212   return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin();
213 }
214 
215 double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
216   return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax();
217 }
218