//===- QuantUtils.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains TOSA numerical support functions and quantization
// attribute builders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"

#include <cassert>
#include <cmath>
#include <limits>

using namespace mlir;
using namespace mlir::tosa;

/// From a scale value, generates multiplier and shift values where the
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0], such that
/// multiplier = round(mantissa * 2^15) and scale ~= multiplier * 2^-shift
/// for 16-bit scaling.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects the right shift to be positive, so fold the 2^15 scaling of
  // the mantissa into the right-shift amount.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);
}

/// From a scale value, generates multiplier and shift values where the
/// mantissa is in [-1.0,-0.5] or [0.5, 1.0], such that
/// multiplier = round(mantissa * 2^31) and scale ~= multiplier * 2^-shift
/// for 32-bit scaling.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");

  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects the right shift to be positive, so fold the 2^31 scaling of
  // the mantissa into the right-shift amount.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);
}

/// Generates a quantized multiplier/shift from a double scale value.
void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
                                           int32_t &shift, int32_t scaleWidth) {

  switch (scaleWidth) {
  case 16:
    computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
    return;
  case 32:
    computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
    return;
  default:
    assert(0 && "Unsupported Tosa quantized_scale regime specified!");
  }
}
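
// For illustration, the 32-bit path maps scale = 1.5 as follows: std::frexp
// yields mantissa = 0.75 and exponent = 1, so the quantized multiplier is
// round(0.75 * 2^31) = 1610612736 and the right shift is -1 + 31 = 30,
// giving 1.5 ~= 1610612736 * 2^-30. A minimal caller sketch (the values are
// illustrative, not taken from any particular TOSA lowering):
//
//   int32_t multiplier, shift;
//   mlir::tosa::computeMultiplierAndShift(1.5, multiplier, shift,
//                                         /*scaleWidth=*/32);
//   // multiplier == 1610612736, shift == 30
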
#define GET_UQTYPE(input_type)                                                 \
  ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
#define GET_QTYPE(input_type)                                                  \
  ((input_type).getElementType().dyn_cast<quant::QuantizedType>())

/// Method to build ConvOpQuantizationAttr, called from
/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
/// input_zp: input zeropoint
/// weight_zp: weight zeropoint.
ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
                                        Value weight) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto weightType = weight.getType().dyn_cast<RankedTensorType>();

  if (!inputType || !weightType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);
  auto weightPerTensorQType = GET_UQTYPE(weightType);
  auto weightPerAxisQType =
      weightType.getElementType()
          .dyn_cast<quant::UniformQuantizedPerAxisType>();

  // Weights must be either per-tensor quantized or per-axis quantized.
  assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
         "Weights must be either per-tensor or per-axis quantized");

  // Either all quantized or all not quantized.
  assert(!((bool)inputQType ^
           ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
         "Inputs and weights must be all quantized or all not quantized");

  if (inputQType) {

    int64_t inputZp = inputQType.getZeroPoint();
    int64_t weightZp = 0;

    if (weightPerTensorQType) {
      weightZp = weightPerTensorQType.getZeroPoint();
    } else if (weightPerAxisQType) {
      weightZp = weightPerAxisQType.getZeroPoints().front();
    }

    auto quantAttr = tosa::ConvOpQuantizationAttr::get(
        builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp),
        builder.getContext());

    return quantAttr;
  }

  return nullptr;
}

/// Builds MatMulOpQuantizationAttr, called from
/// MatMulOpQuantInfoBuilder:
/// aZp: input a zeropoint
/// bZp: input b zeropoint.
MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
                                          Value b) {

  auto aType = a.getType().dyn_cast<RankedTensorType>();
  auto bType = b.getType().dyn_cast<RankedTensorType>();

  if (!aType || !bType)
    return nullptr;

  auto aQType = GET_UQTYPE(aType);
  auto bQType = GET_UQTYPE(bType);

  // A and B are either all quantized or all not quantized.
  assert(!((bool)aQType ^ (bool)bQType) &&
         "Matmul operands must be all quantized or all not quantized");

  if (aQType) {

    int64_t aZp = aQType.getZeroPoint();
    int64_t bZp = bQType.getZeroPoint();

    auto quantAttr = tosa::MatMulOpQuantizationAttr::get(
        builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp),
        builder.getContext());

    return quantAttr;
  }

  return nullptr;
}
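
// A minimal sketch of how the conv/matmul builders above are typically
// consumed (the op handle and the "quantization_info" attribute name are
// illustrative assumptions, not dictated by this file):
//
//   if (auto quantAttr =
//           mlir::tosa::buildConvOpQuantizationAttr(builder, input, weight))
//     convOp->setAttr("quantization_info", quantAttr);
//
// A null return simply means the operands are not quantized, in which case no
// quantization attribute is attached.
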
/// Builds UnaryOpQuantizationAttr, called from
/// UnaryOpQuantInfoBuilder:
/// inputZp: input zeropoint
/// outputZp: output zeropoint.
UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
                                         Type outputRawType) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto outputType = outputRawType.dyn_cast<RankedTensorType>();

  if (!inputType || !outputType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);
  auto outputQType = GET_UQTYPE(outputType);

  // Either all quantized or all not quantized.
  assert(!((bool)inputQType ^ (bool)outputQType) &&
         "Unary inputs/outputs must be all quantized or all not quantized");

  if (inputQType) {

    int64_t inputZp = inputQType.getZeroPoint();
    int64_t outputZp = outputQType.getZeroPoint();

    auto quantAttr = tosa::UnaryOpQuantizationAttr::get(
        builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp),
        builder.getContext());

    return quantAttr;
  }

  return nullptr;
}

/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
/// inputZp: input zeropoint.
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
                                                             Value input) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();

  if (!inputType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);

  if (inputQType) {

    int64_t inputZp = inputQType.getZeroPoint();

    auto quantAttr = tosa::PadOpQuantizationAttr::get(
        builder.getI32IntegerAttr(inputZp), builder.getContext());

    return quantAttr;
  }

  return nullptr;
}

/// Builds the output type for a quantized ConvOp with the right accumulator
/// bitwidth. This is called by the builder when dealing with quantized
/// content.
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
                                           Value input, Value weight) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto weightType = weight.getType().dyn_cast<RankedTensorType>();

  assert(inputType && weightType &&
         "Could not extract input or weight tensors from Conv op");

  auto inputQType = GET_QTYPE(inputType);
  auto weightQType = GET_QTYPE(weightType);

  assert(inputQType && weightQType &&
         "Could not extract input or weight tensor types from Conv op");

  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
  unsigned weightBits = weightQType.getStorageTypeIntegralWidth();

  auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
  assert(outputShapedType &&
         "Could not extract output shape type from Conv op");

  auto outputShape = outputShapedType.getShape();

  IntegerType accElementType;
  if (inputBits == 16 && weightBits == 8)
    accElementType = builder.getIntegerType(48);
  else
    accElementType = builder.getI32Type();
  auto accType = RankedTensorType::get(outputShape, accElementType);
  return accType;
}
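
// For illustration: i8-storage quantized input and weights accumulate into an
// i32 tensor, while an i16-storage input combined with i8-storage weights
// selects a 48-bit accumulator. A sketch with hypothetical types:
//
//   // If outputType is tensor<1x8x8x4x!quant.uniform<...>>, the input storage
//   // type is i16 and the weight storage type is i8, then:
//   Type accType = mlir::tosa::buildConvOpResultTypeInfo(builder, outputType,
//                                                        input, weight);
//   // accType is tensor<1x8x8x4xi48> (the shape is taken from outputType).
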
/// Builds a TOSA quantized type from min/max values.
Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
                                      Attribute minAttr, Attribute maxAttr,
                                      IntegerAttr quantBits, int filterQuantDim,
                                      bool isSigned, BoolAttr narrowRange) {

  quant::QuantizedType retType;

  auto convfunc =
      quant::ExpressedToQuantizedConverter::forInputType(inputDType);

  auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
  auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();

  SmallVector<double, 2> min, max;

  // At least one of the attributes is a per-axis quantized elements attr.
  if (minElems || maxElems) {
    // Both must be elements attrs with the same number of elements.
    if (!minElems || !maxElems ||
        minElems.getNumElements() != maxElems.getNumElements())
      return {};
    min.reserve(minElems.getNumElements());
    max.reserve(maxElems.getNumElements());
    for (auto i : minElems)
      min.push_back(FloatAttr::getValueAsDouble(i));
    for (auto i : maxElems)
      max.push_back(FloatAttr::getValueAsDouble(i));
  } else { // Just a single FP value.
    auto minVal = minAttr.dyn_cast<FloatAttr>();
    if (minVal)
      min.push_back(minVal.getValueAsDouble());
    else
      return {};
    auto maxVal = maxAttr.dyn_cast<FloatAttr>();
    if (maxVal)
      max.push_back(maxVal.getValueAsDouble());
    else
      return {};
  }

  if (min.size() == max.size()) {
    if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
      retType = quant::fakeQuantAttrsToType(
          builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
          narrowRange.getValue(), convfunc.expressedType, isSigned);
    } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
      auto shape = inputDType.dyn_cast<ShapedType>();
      if (!shape)
        return {};
      if (filterQuantDim >= 0 && shape.getRank() > filterQuantDim) {
        retType = quant::fakeQuantAttrsToType(
            builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
            max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
      }
    } else {
      return {};
    }
  } else {
    return {};
  }

  if (!retType)
    return {};

  return convfunc.convert(retType);
}

/// Builds a TypeAttr wrapping the TOSA quantized type built from min/max
/// values.
TypeAttr
mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
                                     Attribute minAttr, Attribute maxAttr,
                                     IntegerAttr quantBits, int filterQuantDim,
                                     bool isSigned, BoolAttr narrowRange) {

  return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
                                            maxAttr, quantBits, filterQuantDim,
                                            isSigned, narrowRange));
}
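
// A minimal sketch of per-tensor usage of buildQTypeFromMinMax (the min/max
// range and bit width are illustrative):
//
//   Type qType = mlir::tosa::buildQTypeFromMinMax(
//       builder, builder.getF32Type(), builder.getF32FloatAttr(-1.0f),
//       builder.getF32FloatAttr(1.0f), builder.getI32IntegerAttr(8),
//       /*filterQuantDim=*/-1, /*isSigned=*/true, builder.getBoolAttr(false));
//
// Scalar FloatAttr min/max take the per-tensor path above and yield a
// quant::UniformQuantizedType with 8-bit signed storage expressed in f32;
// DenseFPElementsAttr min/max take the per-axis path keyed on filterQuantDim.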