1 //===- QuantUtils.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains TOSA numerical support functions and quantization 10 // attribute builders. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 15 16 using namespace mlir; 17 using namespace mlir::tosa; 18 19 /// From a scale value, generates multiplier and shift values where 20 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that 21 /// multiplier = mantissa*2^shift for 16-bit scaling. 22 static void computeMultiplierAndShiftTosaScale16(double scale, 23 int32_t &multiplier, 24 int32_t &shift) { 25 26 const double mantissa = std::frexp(scale, &shift); 27 auto shiftedM = std::round(mantissa * (int64_t(1) << 15)); 28 29 // Can't be greater than 1.0. 30 assert(shiftedM <= (int64_t(1) << 15) && 31 "Shifted mantissa exceeds 16 signed bits"); 32 33 if (shiftedM == (int64_t(1) << 15)) { 34 shiftedM /= 2; 35 shift++; 36 } 37 38 // TOSA expects right shift to be positive and embed (1 << 15) into right 39 // shift bits. 40 shift = (-shift) + 15; 41 42 assert(shiftedM <= std::numeric_limits<int32_t>::max() && 43 "Shifted mantissa exceeds 32-bit signed output type"); 44 45 multiplier = static_cast<int32_t>(shiftedM); 46 47 // Shifting tops out at 62 bits. Right shift to make 62 bits the max. 48 // The limit of 62 on shift allows the shift to be decomposed as 49 // two right shifts of 31. 50 if (shift > 62) { 51 // Shifting the multiplier by more than 31-bits is unnecessary. 52 multiplier = multiplier >> std::min<int32_t>(31, shift - 62); 53 shift = 62; 54 } 55 } 56 57 /// From a scale value, generates multiplier and shift values where 58 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that 59 /// multiplier = mantissa*2^shift for 32-bit scaling. 60 static void computeMultiplierAndShiftTosaScale32(double scale, 61 int32_t &multiplier, 62 int32_t &shift) { 63 64 const double mantissa = std::frexp(scale, &shift); 65 auto shiftedM = std::round(mantissa * (int64_t(1) << 31)); 66 67 // Can't be greater than 1.0. 68 assert(shiftedM <= (int64_t(1) << 31) && 69 "Shifted mantissa exceeds 32 signed bits"); 70 if (shiftedM == (int64_t(1) << 31)) { 71 shiftedM /= 2; 72 shift++; 73 } 74 75 // TOSA expects right shift to be positive, and embed (1 << 31) into right 76 // shift bits. 77 shift = (-shift) + 31; 78 79 assert(shiftedM <= std::numeric_limits<int32_t>::max() && 80 "Shifted mantissa exceeds 32-bit signed output type"); 81 82 multiplier = static_cast<int32_t>(shiftedM); 83 84 // Shifting tops out at 62 bits. Right shift to make 62 bits the max. 85 // The limit of 62 on shift allows the shift to be decomposed as 86 // two right shifts of 31. 87 if (shift > 62) { 88 // Shifting the multiplier by more than 32-bits is unnecessary. 89 multiplier = multiplier >> std::min<int32_t>(31, shift - 62); 90 shift = 62; 91 } 92 } 93 94 /// Generates a quantized multiplier/shift from double. 95 void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, 96 int32_t &shift, int32_t scaleWidth) { 97 98 switch (scaleWidth) { 99 case 16: 100 computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); 101 return; 102 case 32: 103 computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); 104 return; 105 default: 106 assert(0 && "Unsupported Tosa quantized_scale regime specified!"); 107 } 108 } 109 110 #define GET_UQTYPE(input_type) \ 111 ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>()) 112 #define GET_QTYPE(input_type) \ 113 ((input_type).getElementType().dyn_cast<quant::QuantizedType>()) 114 115 /// Method to build ConvOpQuantizationAttr, called from 116 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: 117 /// input_zp: input zeropoint 118 /// weight_zp: weight zeropoint. 119 ConvOpQuantizationAttr 120 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, 121 Value weight) { 122 123 auto inputType = dyn_cast<ShapedType>(input.getType()); 124 auto weightType = dyn_cast<ShapedType>(weight.getType()); 125 126 if (!inputType || !weightType) 127 return nullptr; 128 129 auto inputQType = GET_UQTYPE(inputType); 130 auto weightPerTensorQType = GET_UQTYPE(weightType); 131 auto weightPerAxisQType = 132 dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType()); 133 134 // Weights must be either per-tensor quantized or per-axis quantized. 135 assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && 136 "Weights must be either per-tensor or per-axis quantized"); 137 138 // Either all quantized or all not quantized. 139 assert(!((bool)inputQType ^ 140 ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && 141 "Inputs and weights must be all quantized or all not quantized"); 142 143 if (inputQType) { 144 int64_t inputZp = inputQType.getZeroPoint(); 145 int64_t weightZp = 0; 146 147 if (weightPerTensorQType) { 148 weightZp = weightPerTensorQType.getZeroPoint(); 149 } else if (weightPerAxisQType) { 150 weightZp = weightPerAxisQType.getZeroPoints().front(); 151 } 152 153 return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp); 154 } 155 156 return nullptr; 157 } 158 159 /// Builds MatMulOpQuantizationAttr, called from 160 /// MatMulOpQuantInfoBuilder: 161 /// aZp: input a zeropoint 162 /// bZp: input b zeropoint. 163 MatMulOpQuantizationAttr 164 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, 165 Value b) { 166 167 auto aType = dyn_cast<ShapedType>(a.getType()); 168 auto bType = dyn_cast<ShapedType>(b.getType()); 169 170 if (!aType || !bType) 171 return nullptr; 172 173 auto aQType = GET_UQTYPE(aType); 174 auto bQType = GET_UQTYPE(bType); 175 176 // A and B are either all quantized or all not quantized. 177 assert(!((bool)aQType ^ (bool)bQType) && 178 "Matmul operands must be all quantized or all not quantized"); 179 180 if (aQType) { 181 return builder.getAttr<tosa::MatMulOpQuantizationAttr>( 182 aQType.getZeroPoint(), bQType.getZeroPoint()); 183 } 184 185 return nullptr; 186 } 187 188 /// Builds UnaryOpQuantizationAttr 189 /// UnaryOpQuantInfoBuilder: 190 /// inputZp: input zeropoint 191 /// outputZp: output zeropoint. 192 UnaryOpQuantizationAttr 193 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, 194 Type outputRawType) { 195 196 auto inputType = dyn_cast<ShapedType>(input.getType()); 197 auto outputType = dyn_cast<ShapedType>(outputRawType); 198 199 if (!inputType || !outputType) 200 return nullptr; 201 202 auto inputQType = GET_UQTYPE(inputType); 203 auto outputQType = GET_UQTYPE(outputType); 204 205 // Either all quantized or all not quantized. 206 assert(!((bool)inputQType ^ (bool)outputQType) && 207 "Unary inputs/outputs must be all quantized or all not quantized"); 208 209 if (inputQType) { 210 return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(), 211 outputQType.getZeroPoint()); 212 } 213 214 return nullptr; 215 } 216 217 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: 218 /// inputZp: input zeropoint. 219 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, 220 Value input) { 221 222 auto inputType = dyn_cast<ShapedType>(input.getType()); 223 224 if (!inputType) 225 return nullptr; 226 227 auto inputQType = GET_UQTYPE(inputType); 228 229 if (inputQType) { 230 return builder.getAttr<tosa::PadOpQuantizationAttr>( 231 inputQType.getZeroPoint()); 232 } 233 234 return nullptr; 235 } 236 237 /// Builds output type for a quantized ConvOp with the right bitwidth. 238 /// This is called by the builder when dealing with quantized content. 239 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, 240 Value input, Value weight) { 241 242 auto inputType = dyn_cast<ShapedType>(input.getType()); 243 auto weightType = dyn_cast<ShapedType>(weight.getType()); 244 245 assert(inputType && weightType && 246 "Could not extract input or weight tensors from Conv op"); 247 248 auto inputQType = GET_QTYPE(inputType); 249 auto weightQType = GET_QTYPE(weightType); 250 251 assert(inputQType && weightQType && 252 "Could not extract input or weight tensor types from Conv op"); 253 254 unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); 255 unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); 256 257 auto outputShapedType = dyn_cast<ShapedType>(outputType); 258 assert(outputShapedType && 259 "Could not extract output shape type from Conv op"); 260 261 IntegerType accElementType; 262 if (inputBits == 16 && weightBits == 8) 263 accElementType = builder.getIntegerType(48); 264 else 265 accElementType = builder.getI32Type(); 266 auto accType = outputShapedType.clone(accElementType); 267 return accType; 268 } 269 270 /// Builds Tosa quantization attributes from min/max values. 271 Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, 272 Attribute minAttr, Attribute maxAttr, 273 IntegerAttr quantBits, int filterQuantDim, 274 bool isSigned, BoolAttr narrowRange) { 275 276 quant::QuantizedType retType; 277 278 auto convfunc = 279 quant::ExpressedToQuantizedConverter::forInputType(inputDType); 280 281 auto minElems = dyn_cast<DenseFPElementsAttr>(minAttr); 282 auto maxElems = dyn_cast<DenseFPElementsAttr>(maxAttr); 283 284 SmallVector<double, 2> min, max; 285 286 // At least one is per-axis quantized elementsattr. 287 if (minElems || maxElems) { 288 // Must have the same number of elements. 289 if (minElems.getNumElements() != maxElems.getNumElements()) 290 return {}; 291 min.reserve(minElems.getNumElements()); 292 max.reserve(maxElems.getNumElements()); 293 for (auto i : minElems) 294 min.push_back(FloatAttr::getValueAsDouble(i)); 295 for (auto i : maxElems) 296 max.push_back(FloatAttr::getValueAsDouble(i)); 297 } else { // Just a single FP value. 298 auto minVal = dyn_cast<FloatAttr>(minAttr); 299 if (minVal) 300 min.push_back(minVal.getValueAsDouble()); 301 else 302 return {}; 303 auto maxVal = dyn_cast<FloatAttr>(maxAttr); 304 if (maxVal) 305 max.push_back(maxVal.getValueAsDouble()); 306 else 307 return {}; 308 } 309 310 if (min.size() == max.size()) { 311 if (min.size() == 1) { // Per-tensor quantization with one min/max pair. 312 retType = quant::fakeQuantAttrsToType( 313 builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], 314 narrowRange.getValue(), convfunc.expressedType, isSigned); 315 } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. 316 auto shape = dyn_cast<ShapedType>(inputDType); 317 if (!shape) 318 return {}; 319 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { 320 retType = quant::fakeQuantAttrsToType( 321 builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], 322 max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); 323 } 324 } else { 325 return {}; 326 } 327 } else { 328 return {}; 329 } 330 331 if (!retType) 332 return {}; 333 334 return convfunc.convert(retType); 335 } 336 337 /// Builds Tosa quantization attributes from min/max values. 338 TypeAttr 339 mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, 340 Attribute minAttr, Attribute maxAttr, 341 IntegerAttr quantBits, int filterQuantDim, 342 bool isSigned, BoolAttr narrowRange) { 343 344 return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, 345 maxAttr, quantBits, filterQuantDim, 346 isSigned, narrowRange)); 347 } 348