xref: /llvm-project/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp (revision 852b6486246141e44cc9f126f542a2ae0d73b3d6)
//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8363dd3f3SRob Suderman 
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"

#include <cmath>
11363dd3f3SRob Suderman 
12363dd3f3SRob Suderman using namespace mlir;
13363dd3f3SRob Suderman using namespace mlir::quant;
14363dd3f3SRob Suderman 
15363dd3f3SRob Suderman static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16363dd3f3SRob Suderman                                     bool isSigned, MLIRContext *ctx,
17363dd3f3SRob Suderman                                     Type &storageType, int64_t &qmin,
18363dd3f3SRob Suderman                                     int64_t &qmax) {
19363dd3f3SRob Suderman   // Hard-coded type mapping from TFLite.
20363dd3f3SRob Suderman   if (numBits <= 8) {
211b97cdf8SRiver Riddle     storageType = IntegerType::get(ctx, 8);
22363dd3f3SRob Suderman     if (isSigned) {
23363dd3f3SRob Suderman       qmin = -128;
24363dd3f3SRob Suderman       qmax = 127;
25363dd3f3SRob Suderman     } else {
26363dd3f3SRob Suderman       qmin = 0;
27363dd3f3SRob Suderman       qmax = 255;
28363dd3f3SRob Suderman     }
29363dd3f3SRob Suderman   } else if (numBits <= 16) {
301b97cdf8SRiver Riddle     storageType = IntegerType::get(ctx, 16);
31363dd3f3SRob Suderman     if (isSigned) {
32363dd3f3SRob Suderman       qmin = -32768;
33363dd3f3SRob Suderman       qmax = 32767;
34363dd3f3SRob Suderman     } else {
35363dd3f3SRob Suderman       qmin = 0;
36363dd3f3SRob Suderman       qmax = 65535;
37363dd3f3SRob Suderman     }
38b578c92aSFeng Liu   } else if (numBits <= 32) {
391b97cdf8SRiver Riddle     storageType = IntegerType::get(ctx, 32);
40b578c92aSFeng Liu     if (isSigned) {
41b578c92aSFeng Liu       qmin = std::numeric_limits<int32_t>::min();
42b578c92aSFeng Liu       qmax = std::numeric_limits<int32_t>::max();
43b578c92aSFeng Liu     } else {
44b578c92aSFeng Liu       qmin = std::numeric_limits<uint32_t>::min();
45b578c92aSFeng Liu       qmax = std::numeric_limits<uint32_t>::max();
46b578c92aSFeng Liu     }
47363dd3f3SRob Suderman   } else {
48363dd3f3SRob Suderman     return true;
49363dd3f3SRob Suderman   }
50363dd3f3SRob Suderman 
51363dd3f3SRob Suderman   // Handle narrowRange.
52363dd3f3SRob Suderman   if (narrowRange) {
53363dd3f3SRob Suderman     qmin += 1;
54363dd3f3SRob Suderman   }
55363dd3f3SRob Suderman   return false;
56363dd3f3SRob Suderman }
57363dd3f3SRob Suderman 
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence some values, which are supposed in the original [rmin, rmax]
// range will be outside the shifted range and be clamped during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// Inputs:
//   qmin, qmax - quantized storage range.
//   rmin, rmax - real (expressed) value range.
// Outputs:
//   scale           - (rmax - rmin) / (qmax - qmin).
//   nudgedZeroPoint - integer zero point, always within [qmin, qmax].
//
// If the floating-point zero point is NaN (which happens when rmin or rmax is
// NaN or infinite), the zero point falls back to qmin instead of converting
// NaN to an integer, which would be undefined behavior. `scale` is still
// reported as computed so the caller can see the non-finite value.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = qmin;
  const double qmaxDouble = qmax;
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer.
  // The first test is written as a negated `>=` so that a NaN zero point
  // (every comparison against NaN is false) takes the qmin clamp rather than
  // reaching the float-to-integer conversion below.
  if (!(zeroPointDouble >= qminDouble)) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    // Here qmin <= zeroPointDouble <= qmax, so the conversion is well defined.
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
106363dd3f3SRob Suderman 
107363dd3f3SRob Suderman UniformQuantizedType
108363dd3f3SRob Suderman mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109363dd3f3SRob Suderman                                   double rmax, bool narrowRange,
110363dd3f3SRob Suderman                                   Type expressedType, bool isSigned) {
111363dd3f3SRob Suderman   MLIRContext *ctx = expressedType.getContext();
112363dd3f3SRob Suderman   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113363dd3f3SRob Suderman   Type storageType;
114363dd3f3SRob Suderman   int64_t qmin;
115363dd3f3SRob Suderman   int64_t qmax;
116363dd3f3SRob Suderman   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117363dd3f3SRob Suderman                               qmin, qmax)) {
118363dd3f3SRob Suderman     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
119363dd3f3SRob Suderman             nullptr);
120363dd3f3SRob Suderman   }
121363dd3f3SRob Suderman 
122363dd3f3SRob Suderman   // Special case where min/max is close enough. The tensor contents are all
123363dd3f3SRob Suderman   // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124363dd3f3SRob Suderman   // points and dequantized to 0.0.
125363dd3f3SRob Suderman   if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
12606e25d56SRiver Riddle     return UniformQuantizedType::getChecked(
12706e25d56SRiver Riddle         loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
128363dd3f3SRob Suderman   }
129363dd3f3SRob Suderman 
130363dd3f3SRob Suderman   double scale;
131363dd3f3SRob Suderman   int64_t nudgedZeroPoint;
132363dd3f3SRob Suderman   getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133363dd3f3SRob Suderman 
13406e25d56SRiver Riddle   return UniformQuantizedType::getChecked(loc, flags, storageType,
13506e25d56SRiver Riddle                                           expressedType, scale, nudgedZeroPoint,
13606e25d56SRiver Riddle                                           qmin, qmax);
137363dd3f3SRob Suderman }
138363dd3f3SRob Suderman 
139363dd3f3SRob Suderman UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
140363dd3f3SRob Suderman     Location loc, unsigned numBits, int32_t quantizedDimension,
141363dd3f3SRob Suderman     ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142363dd3f3SRob Suderman     Type expressedType, bool isSigned) {
14302b6fb21SMehdi Amini   size_t axisSize = rmins.size();
14402b6fb21SMehdi Amini   if (axisSize != rmaxs.size()) {
145363dd3f3SRob Suderman     return (emitError(loc, "mismatched per-axis min and max size: ")
14602b6fb21SMehdi Amini                 << axisSize << " vs. " << rmaxs.size(),
147363dd3f3SRob Suderman             nullptr);
148363dd3f3SRob Suderman   }
149363dd3f3SRob Suderman 
150363dd3f3SRob Suderman   MLIRContext *ctx = expressedType.getContext();
151363dd3f3SRob Suderman   Type storageType;
152363dd3f3SRob Suderman   int64_t qmin;
153363dd3f3SRob Suderman   int64_t qmax;
154363dd3f3SRob Suderman   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155363dd3f3SRob Suderman                               qmin, qmax)) {
156363dd3f3SRob Suderman     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
157363dd3f3SRob Suderman             nullptr);
158363dd3f3SRob Suderman   }
159363dd3f3SRob Suderman 
160363dd3f3SRob Suderman   SmallVector<double, 4> scales;
161363dd3f3SRob Suderman   SmallVector<int64_t, 4> zeroPoints;
16202b6fb21SMehdi Amini   scales.reserve(axisSize);
16302b6fb21SMehdi Amini   zeroPoints.reserve(axisSize);
16402b6fb21SMehdi Amini   for (size_t axis = 0; axis != axisSize; ++axis) {
165363dd3f3SRob Suderman     double rmin = rmins[axis];
166363dd3f3SRob Suderman     double rmax = rmaxs[axis];
167363dd3f3SRob Suderman     if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168363dd3f3SRob Suderman       scales.push_back(1.0);
169363dd3f3SRob Suderman       zeroPoints.push_back(qmin);
170363dd3f3SRob Suderman       continue;
171363dd3f3SRob Suderman     }
172363dd3f3SRob Suderman 
173363dd3f3SRob Suderman     double scale;
174363dd3f3SRob Suderman     int64_t nudgedZeroPoint;
175363dd3f3SRob Suderman     getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176363dd3f3SRob Suderman     scales.push_back(scale);
177363dd3f3SRob Suderman     zeroPoints.push_back(nudgedZeroPoint);
178363dd3f3SRob Suderman   }
179363dd3f3SRob Suderman 
180363dd3f3SRob Suderman   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
181363dd3f3SRob Suderman   return UniformQuantizedPerAxisType::getChecked(
18206e25d56SRiver Riddle       loc, flags, storageType, expressedType, scales, zeroPoints,
18306e25d56SRiver Riddle       quantizedDimension, qmin, qmax);
184363dd3f3SRob Suderman }
185