xref: /llvm-project/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (revision 90a5744bebffafb88abf2343a1a70a37e12abef4)
//===- QuantTypes.cpp - Quantization Type Implementation --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TypeDetail.h"
10 #include "mlir/Dialect/Quant/IR/Quant.h"
11 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22 
23 namespace {
24 
25 // Return the minimum scale representable in a given float type
26 double getMinScale(Type expressedType) {
27   auto floatType = cast<FloatType>(expressedType);
28   return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
29 }
30 
31 // Return the maximum scale representable in a given float type
32 double getMaxScale(Type expressedType) {
33   auto floatType = cast<FloatType>(expressedType);
34   return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
35 }
36 
37 }  // namespace
38 
39 unsigned QuantizedType::getFlags() const {
40   return static_cast<ImplType *>(impl)->flags;
41 }
42 
43 bool QuantizedType::classof(Type type) {
44   return llvm::isa<QuantDialect>(type.getDialect());
45 }
46 
47 LogicalResult
48 QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
49                                 unsigned flags, Type storageType,
50                                 Type expressedType, int64_t storageTypeMin,
51                                 int64_t storageTypeMax) {
52   // Verify that the storage type is integral.
53   // This restriction may be lifted at some point in favor of using bf16
54   // or f16 as exact representations on hardware where that is advantageous.
55   auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
56   if (!intStorageType)
57     return emitError() << "storage type must be integral";
58   unsigned integralWidth = intStorageType.getWidth();
59 
60   // Verify storage width.
61   if (integralWidth == 0 || integralWidth > MaxStorageBits)
62     return emitError() << "illegal storage type size: " << integralWidth;
63 
64   // Verify storageTypeMin and storageTypeMax.
65   bool isSigned =
66       (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
67   int64_t defaultIntegerMin =
68       getDefaultMinimumForInteger(isSigned, integralWidth);
69   int64_t defaultIntegerMax =
70       getDefaultMaximumForInteger(isSigned, integralWidth);
71   if (storageTypeMax - storageTypeMin <= 0 ||
72       storageTypeMin < defaultIntegerMin ||
73       storageTypeMax > defaultIntegerMax) {
74     return emitError() << "illegal storage min and storage max: ("
75                        << storageTypeMin << ":" << storageTypeMax << ")";
76   }
77   return success();
78 }
79 
80 Type QuantizedType::getStorageType() const {
81   return static_cast<ImplType *>(impl)->storageType;
82 }
83 
84 int64_t QuantizedType::getStorageTypeMin() const {
85   return static_cast<ImplType *>(impl)->storageTypeMin;
86 }
87 
88 int64_t QuantizedType::getStorageTypeMax() const {
89   return static_cast<ImplType *>(impl)->storageTypeMax;
90 }
91 
92 bool QuantizedType::hasStorageTypeBounds() const {
93   unsigned int integralWidth = getStorageTypeIntegralWidth();
94   bool isSignedInteger = isSigned();
95   int64_t defaultIntegerMin =
96       getDefaultMinimumForInteger(isSignedInteger, integralWidth);
97   int64_t defaultIntegerMax =
98       getDefaultMaximumForInteger(isSignedInteger, integralWidth);
99   return defaultIntegerMin != getStorageTypeMin() ||
100          defaultIntegerMax != getStorageTypeMax();
101 }
102 
103 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
104   // NOTE: If ever supporting non-integral storage types, some other scheme
105   // for determining the width will be needed.
106   return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
107 }
108 
109 Type QuantizedType::getExpressedType() const {
110   return static_cast<ImplType *>(impl)->expressedType;
111 }
112 
113 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
114   if (llvm::isa<ShapedType>(candidateExpressedType)) {
115     return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
116            getExpressedType();
117   }
118   return candidateExpressedType == getExpressedType();
119 }
120 
121 QuantizedType
122 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
123   if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
124     Type elementType =
125         llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
126     return llvm::dyn_cast<QuantizedType>(elementType);
127   }
128   return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
129 }
130 
131 Type QuantizedType::castFromStorageType(Type candidateType) {
132   if (candidateType == getStorageType()) {
133     // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
134     return *this;
135   }
136   if (llvm::isa<RankedTensorType>(candidateType)) {
137     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
138     return RankedTensorType::get(
139         llvm::cast<RankedTensorType>(candidateType).getShape(),
140         getStorageType());
141   }
142   if (llvm::isa<UnrankedTensorType>(candidateType)) {
143     // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
144     return UnrankedTensorType::get(getStorageType());
145   }
146   if (llvm::isa<VectorType>(candidateType)) {
147     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
148     return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
149                            getStorageType());
150   }
151 
152   return nullptr;
153 }
154 
155 Type QuantizedType::castToStorageType(Type quantizedType) {
156   if (llvm::isa<QuantizedType>(quantizedType)) {
157     // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
158     return llvm::cast<QuantizedType>(quantizedType).getStorageType();
159   }
160   if (llvm::isa<ShapedType>(quantizedType)) {
161     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
162     ShapedType sType = llvm::cast<ShapedType>(quantizedType);
163     if (!llvm::isa<QuantizedType>(sType.getElementType())) {
164       return nullptr;
165     }
166     Type storageType =
167         llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
168     if (llvm::isa<RankedTensorType>(quantizedType)) {
169       return RankedTensorType::get(sType.getShape(), storageType);
170     }
171     if (llvm::isa<UnrankedTensorType>(quantizedType)) {
172       return UnrankedTensorType::get(storageType);
173     }
174     if (llvm::isa<VectorType>(quantizedType)) {
175       return VectorType::get(sType.getShape(), storageType);
176     }
177   }
178 
179   return nullptr;
180 }
181 
182 Type QuantizedType::castFromExpressedType(Type candidateType) {
183   if (candidateType == getExpressedType()) {
184     // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
185     return *this;
186   }
187   if (llvm::isa<ShapedType>(candidateType)) {
188     ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
189     if (candidateShapedType.getElementType() != getExpressedType()) {
190       return nullptr;
191     }
192 
193     if (llvm::isa<RankedTensorType>(candidateType)) {
194       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
195       return RankedTensorType::get(candidateShapedType.getShape(), *this);
196     }
197     if (llvm::isa<UnrankedTensorType>(candidateType)) {
198       // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
199       return UnrankedTensorType::get(*this);
200     }
201     if (llvm::isa<VectorType>(candidateType)) {
202       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
203       return VectorType::get(candidateShapedType.getShape(), *this);
204     }
205   }
206 
207   return nullptr;
208 }
209 
210 Type QuantizedType::castToExpressedType(Type quantizedType) {
211   if (llvm::isa<QuantizedType>(quantizedType)) {
212     // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
213     return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
214   }
215   if (llvm::isa<ShapedType>(quantizedType)) {
216     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
217     ShapedType sType = llvm::cast<ShapedType>(quantizedType);
218     if (!llvm::isa<QuantizedType>(sType.getElementType())) {
219       return nullptr;
220     }
221     Type expressedType =
222         llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
223     if (llvm::isa<RankedTensorType>(quantizedType)) {
224       return RankedTensorType::get(sType.getShape(), expressedType);
225     }
226     if (llvm::isa<UnrankedTensorType>(quantizedType)) {
227       return UnrankedTensorType::get(expressedType);
228     }
229     if (llvm::isa<VectorType>(quantizedType)) {
230       return VectorType::get(sType.getShape(), expressedType);
231     }
232   }
233 
234   return nullptr;
235 }
236 
237 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
238   Type expressedQuantizedType = castFromExpressedType(candidateType);
239   if (!expressedQuantizedType) {
240     return nullptr;
241   }
242   return QuantizedType::castToStorageType(expressedQuantizedType);
243 }
244 
245 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
246                                        Type expressedType,
247                                        int64_t storageTypeMin,
248                                        int64_t storageTypeMax) {
249   return Base::get(storageType.getContext(), flags, storageType, expressedType,
250                    storageTypeMin, storageTypeMax);
251 }
252 
253 AnyQuantizedType
254 AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
255                              unsigned flags, Type storageType,
256                              Type expressedType, int64_t storageTypeMin,
257                              int64_t storageTypeMax) {
258   return Base::getChecked(emitError, storageType.getContext(), flags,
259                           storageType, expressedType, storageTypeMin,
260                           storageTypeMax);
261 }
262 
263 LogicalResult
264 AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
265                                    unsigned flags, Type storageType,
266                                    Type expressedType, int64_t storageTypeMin,
267                                    int64_t storageTypeMax) {
268   if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
269                                              expressedType, storageTypeMin,
270                                              storageTypeMax))) {
271     return failure();
272   }
273 
274   // Verify that the expressed type is floating point.
275   // If this restriction is ever eliminated, the parser/printer must be
276   // extended.
277   if (expressedType && !llvm::isa<FloatType>(expressedType))
278     return emitError() << "expressed type must be floating point";
279 
280   return success();
281 }
282 
283 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
284                                                Type expressedType, double scale,
285                                                int64_t zeroPoint,
286                                                int64_t storageTypeMin,
287                                                int64_t storageTypeMax) {
288   return Base::get(storageType.getContext(), flags, storageType, expressedType,
289                    scale, zeroPoint, storageTypeMin, storageTypeMax);
290 }
291 
292 UniformQuantizedType UniformQuantizedType::getChecked(
293     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
294     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
295     int64_t storageTypeMin, int64_t storageTypeMax) {
296   return Base::getChecked(emitError, storageType.getContext(), flags,
297                           storageType, expressedType, scale, zeroPoint,
298                           storageTypeMin, storageTypeMax);
299 }
300 
301 LogicalResult UniformQuantizedType::verifyInvariants(
302     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
303     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
304     int64_t storageTypeMin, int64_t storageTypeMax) {
305   if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
306                                              expressedType, storageTypeMin,
307                                              storageTypeMax))) {
308     return failure();
309   }
310 
311   // Uniform quantization requires fully expressed parameters, including
312   // expressed type.
313   if (!expressedType)
314     return emitError() << "uniform quantization requires expressed type";
315 
316   // Verify that the expressed type is floating point.
317   // If this restriction is ever eliminated, the parser/printer must be
318   // extended.
319   if (!llvm::isa<FloatType>(expressedType))
320     return emitError() << "expressed type must be floating point";
321 
322   // Verify scale.
323   double minScale = getMinScale(expressedType);
324   double maxScale = getMaxScale(expressedType);
325   if (scale < minScale || scale > maxScale)
326     return emitError() << "scale out of expressed type range [" << minScale
327                        << ", " << maxScale << "]";
328 
329   return success();
330 }
331 
332 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
333 
334 int64_t UniformQuantizedType::getZeroPoint() const {
335   return getImpl()->zeroPoint;
336 }
337 
338 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
339     unsigned flags, Type storageType, Type expressedType,
340     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
341     int32_t quantizedDimension, int64_t storageTypeMin,
342     int64_t storageTypeMax) {
343   return Base::get(storageType.getContext(), flags, storageType, expressedType,
344                    scales, zeroPoints, quantizedDimension, storageTypeMin,
345                    storageTypeMax);
346 }
347 
348 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
349     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
350     Type storageType, Type expressedType, ArrayRef<double> scales,
351     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
352     int64_t storageTypeMin, int64_t storageTypeMax) {
353   return Base::getChecked(emitError, storageType.getContext(), flags,
354                           storageType, expressedType, scales, zeroPoints,
355                           quantizedDimension, storageTypeMin, storageTypeMax);
356 }
357 
358 LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
359     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
360     Type storageType, Type expressedType, ArrayRef<double> scales,
361     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
362     int64_t storageTypeMin, int64_t storageTypeMax) {
363   if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
364                                              expressedType, storageTypeMin,
365                                              storageTypeMax))) {
366     return failure();
367   }
368 
369   // Uniform quantization requires fully expressed parameters, including
370   // expressed type.
371   if (!expressedType)
372     return emitError() << "uniform quantization requires expressed type";
373 
374   // Verify that the expressed type is floating point.
375   // If this restriction is ever eliminated, the parser/printer must be
376   // extended.
377   if (!llvm::isa<FloatType>(expressedType))
378     return emitError() << "expressed type must be floating point";
379 
380   // Ensure that the number of scales and zeroPoints match.
381   if (scales.size() != zeroPoints.size())
382     return emitError() << "illegal number of scales and zeroPoints: "
383                        << scales.size() << ", " << zeroPoints.size();
384 
385   // Verify scale.
386   double minScale = getMinScale(expressedType);
387   double maxScale = getMaxScale(expressedType);
388   for (double scale : scales) {
389     if (scale < minScale || scale > maxScale)
390       return emitError() << "scale out of expressed type range [" << minScale
391                          << ", " << maxScale << "]";
392   }
393 
394   // Verify quantized dimension.
395   if (quantizedDimension < 0)
396     return emitError() << "illegal quantized dimension: " << quantizedDimension;
397 
398   return success();
399 }
400 
401 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
402   return getImpl()->getScales();
403 }
404 
405 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
406   return getImpl()->getZeroPoints();
407 }
408 
409 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
410   return getImpl()->quantizedDimension;
411 }
412 
413 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
414                                                      double min, double max) {
415   return Base::get(expressedType.getContext(), expressedType, min, max);
416 }
417 
418 CalibratedQuantizedType CalibratedQuantizedType::getChecked(
419     function_ref<InFlightDiagnostic()> emitError, Type expressedType,
420     double min, double max) {
421   return Base::getChecked(emitError, expressedType.getContext(), expressedType,
422                           min, max);
423 }
424 
425 LogicalResult CalibratedQuantizedType::verifyInvariants(
426     function_ref<InFlightDiagnostic()> emitError, Type expressedType,
427     double min, double max) {
428   // Verify that the expressed type is floating point.
429   // If this restriction is ever eliminated, the parser/printer must be
430   // extended.
431   if (!llvm::isa<FloatType>(expressedType))
432     return emitError() << "expressed type must be floating point";
433   if (max <= min)
434     return emitError() << "illegal min and max: (" << min << ":" << max << ")";
435 
436   return success();
437 }
438 
439 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
440 
441 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
442