1 //===- QuantUtils.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains TOSA numerical support functions and quantization 10 // attribute builders. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 15 16 using namespace mlir; 17 using namespace mlir::tosa; 18 19 /// From a scale value, generates multiplier and shift values where 20 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that 21 /// multiplier = mantissa*2^shift for 16-bit scaling. 22 static void computeMultiplierAndShiftTosaScale16(double scale, 23 int32_t &multiplier, 24 int32_t &shift) { 25 26 const double mantissa = std::frexp(scale, &shift); 27 auto shiftedM = std::round(mantissa * (int64_t(1) << 15)); 28 29 // Can't be greater than 1.0. 30 assert(shiftedM <= (int64_t(1) << 15) && 31 "Shifted mantissa exceeds 16 signed bits"); 32 33 if (shiftedM == (int64_t(1) << 15)) { 34 shiftedM /= 2; 35 shift++; 36 } 37 38 // TOSA expects right shift to be positive and embed (1 << 15) into right 39 // shift bits. 40 shift = (-shift) + 15; 41 42 assert(shiftedM <= std::numeric_limits<int32_t>::max() && 43 "Shifted mantissa exceeds 32-bit signed output type"); 44 45 multiplier = static_cast<int32_t>(shiftedM); 46 47 // Shifting tops out at 63 bits. Right shift to make 63 bits the max. 48 if (shift > 63) { 49 // Shifting the multiplier by more than 32-bits is unnecessary. 50 multiplier = multiplier >> std::min<int32_t>(32, shift - 63); 51 shift = 63; 52 } 53 } 54 55 /// From a scale value, generates multiplier and shift values where 56 /// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that 57 /// multiplier = mantissa*2^shift for 32-bit scaling. 58 static void computeMultiplierAndShiftTosaScale32(double scale, 59 int32_t &multiplier, 60 int32_t &shift) { 61 62 const double mantissa = std::frexp(scale, &shift); 63 auto shiftedM = std::round(mantissa * (int64_t(1) << 31)); 64 65 // Can't be greater than 1.0. 66 assert(shiftedM <= (int64_t(1) << 31) && 67 "Shifted mantissa exceeds 32 signed bits"); 68 if (shiftedM == (int64_t(1) << 31)) { 69 shiftedM /= 2; 70 shift++; 71 } 72 73 // TOSA expects right shift to be positive, and embed (1 << 31) into right 74 // shift bits. 75 shift = (-shift) + 31; 76 77 assert(shiftedM <= std::numeric_limits<int32_t>::max() && 78 "Shifted mantissa exceeds 32-bit signed output type"); 79 80 multiplier = static_cast<int32_t>(shiftedM); 81 82 // Shifting tops out at 63 bits. Right shift to make 63 bits the max. 83 if (shift > 63) { 84 // Shifting the multiplier by more than 32-bits is unnecessary. 85 multiplier = multiplier >> std::min<int32_t>(32, shift - 63); 86 shift = 63; 87 } 88 } 89 90 /// Generates a quantized multiplier/shift from double. 91 void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, 92 int32_t &shift, int32_t scaleWidth) { 93 94 switch (scaleWidth) { 95 case 16: 96 computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); 97 return; 98 case 32: 99 computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); 100 return; 101 default: 102 assert(0 && "Unsupported Tosa quantized_scale regime specified!"); 103 } 104 } 105 106 #define GET_UQTYPE(input_type) \ 107 ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>()) 108 #define GET_QTYPE(input_type) \ 109 ((input_type).getElementType().dyn_cast<quant::QuantizedType>()) 110 111 /// Method to build ConvOpQuantizationAttr, called from 112 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: 113 /// input_zp: input zeropoint 114 /// weight_zp: weight zeropoint. 115 ConvOpQuantizationAttr 116 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, 117 Value weight) { 118 119 auto inputType = input.getType().dyn_cast<ShapedType>(); 120 auto weightType = weight.getType().dyn_cast<ShapedType>(); 121 122 if (!inputType || !weightType) 123 return nullptr; 124 125 auto inputQType = GET_UQTYPE(inputType); 126 auto weightPerTensorQType = GET_UQTYPE(weightType); 127 auto weightPerAxisQType = weightType.getElementType() 128 .dyn_cast<quant::UniformQuantizedPerAxisType>(); 129 130 // Weights must be either per-tensor quantized or per-axis quantized. 131 assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && 132 "Weights must be either per-tensor or per-axis quantized"); 133 134 // Either all quantized or all not quantized. 135 assert(!((bool)inputQType ^ 136 ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && 137 "Inputs and weights must be all quantized or all not quantized"); 138 139 if (inputQType) { 140 141 int64_t inputZp = inputQType.getZeroPoint(); 142 int64_t weightZp = 0; 143 144 if (weightPerTensorQType) { 145 weightZp = weightPerTensorQType.getZeroPoint(); 146 } else if (weightPerAxisQType) { 147 weightZp = weightPerAxisQType.getZeroPoints().front(); 148 } 149 150 auto quantAttr = tosa::ConvOpQuantizationAttr::get( 151 builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), 152 builder.getContext()); 153 154 return quantAttr; 155 } 156 157 return nullptr; 158 } 159 160 /// Builds MatMulOpQuantizationAttr, called from 161 /// MatMulOpQuantInfoBuilder: 162 /// aZp: input a zeropoint 163 /// bZp: input b zeropoint. 164 MatMulOpQuantizationAttr 165 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, 166 Value b) { 167 168 auto aType = a.getType().dyn_cast<ShapedType>(); 169 auto bType = b.getType().dyn_cast<ShapedType>(); 170 171 if (!aType || !bType) 172 return nullptr; 173 174 auto aQType = GET_UQTYPE(aType); 175 auto bQType = GET_UQTYPE(bType); 176 177 // A and B are either all quantized or all not quantized. 178 assert(!((bool)aQType ^ (bool)bQType) && 179 "Matmul operands must be all quantized or all not quantized"); 180 181 if (aQType) { 182 183 int64_t aZp = aQType.getZeroPoint(); 184 int64_t bZp = bQType.getZeroPoint(); 185 186 auto quantAttr = tosa::MatMulOpQuantizationAttr::get( 187 builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), 188 builder.getContext()); 189 190 return quantAttr; 191 } 192 193 return nullptr; 194 } 195 196 /// Builds UnaryOpQuantizationAttr 197 /// UnaryOpQuantInfoBuilder: 198 /// inputZp: input zeropoint 199 /// outputZp: output zeropoint. 200 UnaryOpQuantizationAttr 201 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, 202 Type outputRawType) { 203 204 auto inputType = input.getType().dyn_cast<ShapedType>(); 205 auto outputType = outputRawType.dyn_cast<ShapedType>(); 206 207 if (!inputType || !outputType) 208 return nullptr; 209 210 auto inputQType = GET_UQTYPE(inputType); 211 auto outputQType = GET_UQTYPE(outputType); 212 213 // Either all quantized or all not quantized. 214 assert(!((bool)inputQType ^ (bool)outputQType) && 215 "Unary inputs/outputs must be all quantized or all not quantized"); 216 217 if (inputQType) { 218 219 int64_t inputZp = inputQType.getZeroPoint(); 220 int64_t outputZp = outputQType.getZeroPoint(); 221 222 auto quantAttr = tosa::UnaryOpQuantizationAttr::get( 223 builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), 224 builder.getContext()); 225 226 return quantAttr; 227 } 228 229 return nullptr; 230 } 231 232 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: 233 /// inputZp: input zeropoint. 234 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, 235 Value input) { 236 237 auto inputType = input.getType().dyn_cast<ShapedType>(); 238 239 if (!inputType) 240 return nullptr; 241 242 auto inputQType = GET_UQTYPE(inputType); 243 244 if (inputQType) { 245 246 int64_t inputZp = inputQType.getZeroPoint(); 247 248 auto quantAttr = tosa::PadOpQuantizationAttr::get( 249 builder.getI32IntegerAttr(inputZp), builder.getContext()); 250 251 return quantAttr; 252 } 253 254 return nullptr; 255 } 256 257 /// Builds output type for a quantized ConvOp with the right bitwidth. 258 /// This is called by the builder when dealing with quantized content. 259 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, 260 Value input, Value weight) { 261 262 auto inputType = input.getType().dyn_cast<ShapedType>(); 263 auto weightType = weight.getType().dyn_cast<ShapedType>(); 264 265 assert(inputType && weightType && 266 "Could not extract input or weight tensors from Conv op"); 267 268 auto inputQType = GET_QTYPE(inputType); 269 auto weightQType = GET_QTYPE(weightType); 270 271 assert(inputQType && weightQType && 272 "Could not extract input or weight tensor types from Conv op"); 273 274 unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); 275 unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); 276 277 auto outputShapedType = outputType.dyn_cast<ShapedType>(); 278 assert(outputShapedType && 279 "Could not extract output shape type from Conv op"); 280 281 IntegerType accElementType; 282 if (inputBits == 16 && weightBits == 8) 283 accElementType = builder.getIntegerType(48); 284 else 285 accElementType = builder.getI32Type(); 286 auto accType = outputShapedType.clone(accElementType); 287 return accType; 288 } 289 290 /// Builds Tosa quantization attributes from min/max values. 291 Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, 292 Attribute minAttr, Attribute maxAttr, 293 IntegerAttr quantBits, int filterQuantDim, 294 bool isSigned, BoolAttr narrowRange) { 295 296 quant::QuantizedType retType; 297 298 auto convfunc = 299 quant::ExpressedToQuantizedConverter::forInputType(inputDType); 300 301 auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>(); 302 auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>(); 303 304 SmallVector<double, 2> min, max; 305 306 // At least one is per-axis quantized elementsattr. 307 if (minElems || maxElems) { 308 // Must have the same number of elements. 309 if (minElems.getNumElements() != maxElems.getNumElements()) 310 return {}; 311 min.reserve(minElems.getNumElements()); 312 max.reserve(maxElems.getNumElements()); 313 for (auto i : minElems) 314 min.push_back(FloatAttr::getValueAsDouble(i)); 315 for (auto i : maxElems) 316 max.push_back(FloatAttr::getValueAsDouble(i)); 317 } else { // Just a single FP value. 318 auto minVal = minAttr.dyn_cast<FloatAttr>(); 319 if (minVal) 320 min.push_back(minVal.getValueAsDouble()); 321 else 322 return {}; 323 auto maxVal = maxAttr.dyn_cast<FloatAttr>(); 324 if (maxVal) 325 max.push_back(maxVal.getValueAsDouble()); 326 else 327 return {}; 328 } 329 330 if (min.size() == max.size()) { 331 if (min.size() == 1) { // Per-tensor quantization with one min/max pair. 332 retType = quant::fakeQuantAttrsToType( 333 builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], 334 narrowRange.getValue(), convfunc.expressedType, isSigned); 335 } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. 336 auto shape = inputDType.dyn_cast<ShapedType>(); 337 if (!shape) 338 return {}; 339 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { 340 retType = quant::fakeQuantAttrsToType( 341 builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], 342 max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); 343 } 344 } else { 345 return {}; 346 } 347 } else { 348 return {}; 349 } 350 351 if (!retType) 352 return {}; 353 354 return convfunc.convert(retType); 355 } 356 357 /// Builds Tosa quantization attributes from min/max values. 358 TypeAttr 359 mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, 360 Attribute minAttr, Attribute maxAttr, 361 IntegerAttr quantBits, int filterQuantDim, 362 bool isSigned, BoolAttr narrowRange) { 363 364 return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, 365 maxAttr, quantBits, filterQuantDim, 366 isSigned, narrowRange)); 367 } 368