1 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Quant/IR/QuantTypes.h" 10 #include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" 11 12 using namespace mlir; 13 using namespace mlir::quant; 14 15 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, 16 bool isSigned, MLIRContext *ctx, 17 Type &storageType, int64_t &qmin, 18 int64_t &qmax) { 19 // Hard-coded type mapping from TFLite. 20 if (numBits <= 8) { 21 storageType = IntegerType::get(ctx, 8); 22 if (isSigned) { 23 qmin = -128; 24 qmax = 127; 25 } else { 26 qmin = 0; 27 qmax = 255; 28 } 29 } else if (numBits <= 16) { 30 storageType = IntegerType::get(ctx, 16); 31 if (isSigned) { 32 qmin = -32768; 33 qmax = 32767; 34 } else { 35 qmin = 0; 36 qmax = 65535; 37 } 38 } else if (numBits <= 32) { 39 storageType = IntegerType::get(ctx, 32); 40 if (isSigned) { 41 qmin = std::numeric_limits<int32_t>::min(); 42 qmax = std::numeric_limits<int32_t>::max(); 43 } else { 44 qmin = std::numeric_limits<uint32_t>::min(); 45 qmax = std::numeric_limits<uint32_t>::max(); 46 } 47 } else { 48 return true; 49 } 50 51 // Handle narrowRange. 52 if (narrowRange) { 53 qmin += 1; 54 } 55 return false; 56 } 57 58 // This is a specific implementation of nudging: 59 // If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted 60 // to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero 61 // point is derived from the shifted range, and the scale isn't changed. As 62 // a consequence some values, which are supposed in the original [rmin, rmax] 63 // range will be outside the shifted range and be clamped during quantization. 64 // TODO: we should nudge the scale as well, but that requires the 65 // fake quant op used in the training to use the nudged scale as well. 66 static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, 67 double rmax, double &scale, 68 int64_t &nudgedZeroPoint) { 69 // Determine the scale. 70 const double qminDouble = qmin; 71 const double qmaxDouble = qmax; 72 scale = (rmax - rmin) / (qmaxDouble - qminDouble); 73 74 // Zero point computation. 75 // In float, solve the affine equation for any known pair 76 // (real value, corresponding quantized value), of which, two such pairs 77 // are known: (rmin, qmin), (rmax, qmax). 78 // The arithmetic error on the zero point computed from either pair will be 79 // roughly machine_epsilon * (sum of absolute values of terms). 80 // Use the variant that adds the smaller error. 81 const double zeroPointFromMin = qminDouble - rmin / scale; 82 const double zeroPointFromMinError = 83 std::abs(qminDouble) + std::abs(rmin / scale); 84 const double zeroPointFromMax = qmaxDouble - rmax / scale; 85 const double zeroPointFromMaxError = 86 std::abs(qmaxDouble) + std::abs(rmax / scale); 87 88 const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) 89 ? zeroPointFromMin 90 : zeroPointFromMax; 91 92 // Now nudge the zero point to be an integer. 93 nudgedZeroPoint = 0; 94 if (zeroPointDouble < qminDouble) { 95 nudgedZeroPoint = qmin; 96 } else if (zeroPointDouble > qmaxDouble) { 97 nudgedZeroPoint = qmax; 98 } else { 99 nudgedZeroPoint = round(zeroPointDouble); 100 } 101 102 // By construction, the nudged zero point should always be in range. 103 assert(nudgedZeroPoint >= qmin); 104 assert(nudgedZeroPoint <= qmax); 105 } 106 107 UniformQuantizedType 108 mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, 109 double rmax, bool narrowRange, 110 Type expressedType, bool isSigned) { 111 MLIRContext *ctx = expressedType.getContext(); 112 unsigned flags = isSigned ? QuantizationFlags::Signed : 0; 113 Type storageType; 114 int64_t qmin; 115 int64_t qmax; 116 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType, 117 qmin, qmax)) { 118 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits, 119 nullptr); 120 } 121 122 // Special case where min/max is close enough. The tensor contents are all 123 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero 124 // points and dequantized to 0.0. 125 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 126 return UniformQuantizedType::getChecked( 127 loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax); 128 } 129 130 double scale; 131 int64_t nudgedZeroPoint; 132 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); 133 134 return UniformQuantizedType::getChecked(loc, flags, storageType, 135 expressedType, scale, nudgedZeroPoint, 136 qmin, qmax); 137 } 138 139 UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType( 140 Location loc, unsigned numBits, int32_t quantizedDimension, 141 ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange, 142 Type expressedType, bool isSigned) { 143 size_t axisSize = rmins.size(); 144 if (axisSize != rmaxs.size()) { 145 return (emitError(loc, "mismatched per-axis min and max size: ") 146 << axisSize << " vs. " << rmaxs.size(), 147 nullptr); 148 } 149 150 MLIRContext *ctx = expressedType.getContext(); 151 Type storageType; 152 int64_t qmin; 153 int64_t qmax; 154 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType, 155 qmin, qmax)) { 156 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits, 157 nullptr); 158 } 159 160 SmallVector<double, 4> scales; 161 SmallVector<int64_t, 4> zeroPoints; 162 scales.reserve(axisSize); 163 zeroPoints.reserve(axisSize); 164 for (size_t axis = 0; axis != axisSize; ++axis) { 165 double rmin = rmins[axis]; 166 double rmax = rmaxs[axis]; 167 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 168 scales.push_back(1.0); 169 zeroPoints.push_back(qmin); 170 continue; 171 } 172 173 double scale; 174 int64_t nudgedZeroPoint; 175 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); 176 scales.push_back(scale); 177 zeroPoints.push_back(nudgedZeroPoint); 178 } 179 180 unsigned flags = isSigned ? QuantizationFlags::Signed : 0; 181 return UniformQuantizedPerAxisType::getChecked( 182 loc, flags, storageType, expressedType, scales, zeroPoints, 183 quantizedDimension, qmin, qmax); 184 } 185