1363dd3f3SRob Suderman //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===// 2363dd3f3SRob Suderman // 3363dd3f3SRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4363dd3f3SRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5363dd3f3SRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6363dd3f3SRob Suderman // 7363dd3f3SRob Suderman //===----------------------------------------------------------------------===// 8363dd3f3SRob Suderman 9*852b6486SRafael Ubal #include "mlir/Dialect/Quant/IR/QuantTypes.h" 10*852b6486SRafael Ubal #include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" 11363dd3f3SRob Suderman 12363dd3f3SRob Suderman using namespace mlir; 13363dd3f3SRob Suderman using namespace mlir::quant; 14363dd3f3SRob Suderman 15363dd3f3SRob Suderman static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, 16363dd3f3SRob Suderman bool isSigned, MLIRContext *ctx, 17363dd3f3SRob Suderman Type &storageType, int64_t &qmin, 18363dd3f3SRob Suderman int64_t &qmax) { 19363dd3f3SRob Suderman // Hard-coded type mapping from TFLite. 20363dd3f3SRob Suderman if (numBits <= 8) { 211b97cdf8SRiver Riddle storageType = IntegerType::get(ctx, 8); 22363dd3f3SRob Suderman if (isSigned) { 23363dd3f3SRob Suderman qmin = -128; 24363dd3f3SRob Suderman qmax = 127; 25363dd3f3SRob Suderman } else { 26363dd3f3SRob Suderman qmin = 0; 27363dd3f3SRob Suderman qmax = 255; 28363dd3f3SRob Suderman } 29363dd3f3SRob Suderman } else if (numBits <= 16) { 301b97cdf8SRiver Riddle storageType = IntegerType::get(ctx, 16); 31363dd3f3SRob Suderman if (isSigned) { 32363dd3f3SRob Suderman qmin = -32768; 33363dd3f3SRob Suderman qmax = 32767; 34363dd3f3SRob Suderman } else { 35363dd3f3SRob Suderman qmin = 0; 36363dd3f3SRob Suderman qmax = 65535; 37363dd3f3SRob Suderman } 38b578c92aSFeng Liu } else if (numBits <= 32) { 391b97cdf8SRiver Riddle storageType = IntegerType::get(ctx, 32); 40b578c92aSFeng Liu if (isSigned) { 41b578c92aSFeng Liu qmin = std::numeric_limits<int32_t>::min(); 42b578c92aSFeng Liu qmax = std::numeric_limits<int32_t>::max(); 43b578c92aSFeng Liu } else { 44b578c92aSFeng Liu qmin = std::numeric_limits<uint32_t>::min(); 45b578c92aSFeng Liu qmax = std::numeric_limits<uint32_t>::max(); 46b578c92aSFeng Liu } 47363dd3f3SRob Suderman } else { 48363dd3f3SRob Suderman return true; 49363dd3f3SRob Suderman } 50363dd3f3SRob Suderman 51363dd3f3SRob Suderman // Handle narrowRange. 52363dd3f3SRob Suderman if (narrowRange) { 53363dd3f3SRob Suderman qmin += 1; 54363dd3f3SRob Suderman } 55363dd3f3SRob Suderman return false; 56363dd3f3SRob Suderman } 57363dd3f3SRob Suderman 58363dd3f3SRob Suderman // This is a specific implementation of nudging: 59363dd3f3SRob Suderman // If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted 60363dd3f3SRob Suderman // to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero 61363dd3f3SRob Suderman // point is derived from the shifted range, and the scale isn't changed. As 62363dd3f3SRob Suderman // a consequence some values, which are supposed in the original [rmin, rmax] 63363dd3f3SRob Suderman // range will be outside the shifted range and be clamped during quantization. 649db53a18SRiver Riddle // TODO: we should nudge the scale as well, but that requires the 65363dd3f3SRob Suderman // fake quant op used in the training to use the nudged scale as well. 66363dd3f3SRob Suderman static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, 67363dd3f3SRob Suderman double rmax, double &scale, 68363dd3f3SRob Suderman int64_t &nudgedZeroPoint) { 69363dd3f3SRob Suderman // Determine the scale. 70363dd3f3SRob Suderman const double qminDouble = qmin; 71363dd3f3SRob Suderman const double qmaxDouble = qmax; 72363dd3f3SRob Suderman scale = (rmax - rmin) / (qmaxDouble - qminDouble); 73363dd3f3SRob Suderman 74363dd3f3SRob Suderman // Zero point computation. 75363dd3f3SRob Suderman // In float, solve the affine equation for any known pair 76363dd3f3SRob Suderman // (real value, corresponding quantized value), of which, two such pairs 77363dd3f3SRob Suderman // are known: (rmin, qmin), (rmax, qmax). 78363dd3f3SRob Suderman // The arithmetic error on the zero point computed from either pair will be 79363dd3f3SRob Suderman // roughly machine_epsilon * (sum of absolute values of terms). 80363dd3f3SRob Suderman // Use the variant that adds the smaller error. 81363dd3f3SRob Suderman const double zeroPointFromMin = qminDouble - rmin / scale; 82363dd3f3SRob Suderman const double zeroPointFromMinError = 83363dd3f3SRob Suderman std::abs(qminDouble) + std::abs(rmin / scale); 84363dd3f3SRob Suderman const double zeroPointFromMax = qmaxDouble - rmax / scale; 85363dd3f3SRob Suderman const double zeroPointFromMaxError = 86363dd3f3SRob Suderman std::abs(qmaxDouble) + std::abs(rmax / scale); 87363dd3f3SRob Suderman 88363dd3f3SRob Suderman const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) 89363dd3f3SRob Suderman ? zeroPointFromMin 90363dd3f3SRob Suderman : zeroPointFromMax; 91363dd3f3SRob Suderman 92363dd3f3SRob Suderman // Now nudge the zero point to be an integer. 93363dd3f3SRob Suderman nudgedZeroPoint = 0; 94363dd3f3SRob Suderman if (zeroPointDouble < qminDouble) { 95363dd3f3SRob Suderman nudgedZeroPoint = qmin; 96363dd3f3SRob Suderman } else if (zeroPointDouble > qmaxDouble) { 97363dd3f3SRob Suderman nudgedZeroPoint = qmax; 98363dd3f3SRob Suderman } else { 99363dd3f3SRob Suderman nudgedZeroPoint = round(zeroPointDouble); 100363dd3f3SRob Suderman } 101363dd3f3SRob Suderman 102363dd3f3SRob Suderman // By construction, the nudged zero point should always be in range. 103363dd3f3SRob Suderman assert(nudgedZeroPoint >= qmin); 104363dd3f3SRob Suderman assert(nudgedZeroPoint <= qmax); 105363dd3f3SRob Suderman } 106363dd3f3SRob Suderman 107363dd3f3SRob Suderman UniformQuantizedType 108363dd3f3SRob Suderman mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, 109363dd3f3SRob Suderman double rmax, bool narrowRange, 110363dd3f3SRob Suderman Type expressedType, bool isSigned) { 111363dd3f3SRob Suderman MLIRContext *ctx = expressedType.getContext(); 112363dd3f3SRob Suderman unsigned flags = isSigned ? QuantizationFlags::Signed : 0; 113363dd3f3SRob Suderman Type storageType; 114363dd3f3SRob Suderman int64_t qmin; 115363dd3f3SRob Suderman int64_t qmax; 116363dd3f3SRob Suderman if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType, 117363dd3f3SRob Suderman qmin, qmax)) { 118363dd3f3SRob Suderman return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits, 119363dd3f3SRob Suderman nullptr); 120363dd3f3SRob Suderman } 121363dd3f3SRob Suderman 122363dd3f3SRob Suderman // Special case where min/max is close enough. The tensor contents are all 123363dd3f3SRob Suderman // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero 124363dd3f3SRob Suderman // points and dequantized to 0.0. 125363dd3f3SRob Suderman if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 12606e25d56SRiver Riddle return UniformQuantizedType::getChecked( 12706e25d56SRiver Riddle loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax); 128363dd3f3SRob Suderman } 129363dd3f3SRob Suderman 130363dd3f3SRob Suderman double scale; 131363dd3f3SRob Suderman int64_t nudgedZeroPoint; 132363dd3f3SRob Suderman getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); 133363dd3f3SRob Suderman 13406e25d56SRiver Riddle return UniformQuantizedType::getChecked(loc, flags, storageType, 13506e25d56SRiver Riddle expressedType, scale, nudgedZeroPoint, 13606e25d56SRiver Riddle qmin, qmax); 137363dd3f3SRob Suderman } 138363dd3f3SRob Suderman 139363dd3f3SRob Suderman UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType( 140363dd3f3SRob Suderman Location loc, unsigned numBits, int32_t quantizedDimension, 141363dd3f3SRob Suderman ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange, 142363dd3f3SRob Suderman Type expressedType, bool isSigned) { 14302b6fb21SMehdi Amini size_t axisSize = rmins.size(); 14402b6fb21SMehdi Amini if (axisSize != rmaxs.size()) { 145363dd3f3SRob Suderman return (emitError(loc, "mismatched per-axis min and max size: ") 14602b6fb21SMehdi Amini << axisSize << " vs. " << rmaxs.size(), 147363dd3f3SRob Suderman nullptr); 148363dd3f3SRob Suderman } 149363dd3f3SRob Suderman 150363dd3f3SRob Suderman MLIRContext *ctx = expressedType.getContext(); 151363dd3f3SRob Suderman Type storageType; 152363dd3f3SRob Suderman int64_t qmin; 153363dd3f3SRob Suderman int64_t qmax; 154363dd3f3SRob Suderman if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType, 155363dd3f3SRob Suderman qmin, qmax)) { 156363dd3f3SRob Suderman return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits, 157363dd3f3SRob Suderman nullptr); 158363dd3f3SRob Suderman } 159363dd3f3SRob Suderman 160363dd3f3SRob Suderman SmallVector<double, 4> scales; 161363dd3f3SRob Suderman SmallVector<int64_t, 4> zeroPoints; 16202b6fb21SMehdi Amini scales.reserve(axisSize); 16302b6fb21SMehdi Amini zeroPoints.reserve(axisSize); 16402b6fb21SMehdi Amini for (size_t axis = 0; axis != axisSize; ++axis) { 165363dd3f3SRob Suderman double rmin = rmins[axis]; 166363dd3f3SRob Suderman double rmax = rmaxs[axis]; 167363dd3f3SRob Suderman if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 168363dd3f3SRob Suderman scales.push_back(1.0); 169363dd3f3SRob Suderman zeroPoints.push_back(qmin); 170363dd3f3SRob Suderman continue; 171363dd3f3SRob Suderman } 172363dd3f3SRob Suderman 173363dd3f3SRob Suderman double scale; 174363dd3f3SRob Suderman int64_t nudgedZeroPoint; 175363dd3f3SRob Suderman getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); 176363dd3f3SRob Suderman scales.push_back(scale); 177363dd3f3SRob Suderman zeroPoints.push_back(nudgedZeroPoint); 178363dd3f3SRob Suderman } 179363dd3f3SRob Suderman 180363dd3f3SRob Suderman unsigned flags = isSigned ? QuantizationFlags::Signed : 0; 181363dd3f3SRob Suderman return UniformQuantizedPerAxisType::getChecked( 18206e25d56SRiver Riddle loc, flags, storageType, expressedType, scales, zeroPoints, 18306e25d56SRiver Riddle quantizedDimension, qmin, qmax); 184363dd3f3SRob Suderman } 185