//===- QuantUtils.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains TOSA numerical support functions and quantization
// attribute builders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"

using namespace mlir;
using namespace mlir::tosa;

/// From a scale value, generates multiplier and shift values where the
/// mantissa is in [-1.0,-0.5] or [0.5,1.0], such that
/// scale ~= multiplier * 2^-shift with a 16-bit multiplier.
void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier,
                                          int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // The mantissa can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects the right shift to be positive; fold the (1 << 15) mantissa
  // scaling into the shift.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);
}

/// From a scale value, generates multiplier and shift values where the
/// mantissa is in [-1.0,-0.5] or [0.5,1.0], such that
/// scale ~= multiplier * 2^-shift with a 32-bit multiplier.
void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier,
                                          int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // The mantissa can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects the right shift to be positive; fold the (1 << 31) mantissa
  // scaling into the shift.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);
}

/// Generates a quantized multiplier/shift from a double scale value.
void computeMultiplierAndShift(double scale, int32_t &multiplier,
                               int32_t &shift, int32_t scaleWidth) {

  switch (scaleWidth) {
  case 16:
    computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
    return;
  case 32:
    computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
    return;
  default:
    assert(0 && "Unsupported Tosa quantized_scale regime specified!");
  }
}
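
// Illustrative sketch (editor's example, not part of the original file): for
// a hypothetical rescale factor of 0.5 with 32-bit scaling, std::frexp yields
// mantissa 0.5 and exponent 0, so shiftedM = round(0.5 * 2^31) = 2^30 and
// shift = -0 + 31 = 31. A caller could check this as:
//
//   int32_t multiplier, shift;
//   computeMultiplierAndShift(0.5, multiplier, shift, /*scaleWidth=*/32);
//   assert(multiplier == (1 << 30) && shift == 31);
//   // multiplier * 2^-shift == 2^30 / 2^31 == 0.5, recovering the scale.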

#define GET_UQTYPE(input_type)                                                 \
  ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
#define GET_QTYPE(input_type)                                                  \
  ((input_type).getElementType().dyn_cast<quant::QuantizedType>())

/// Method to build ConvOpQuantizationAttr, called from
/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
///   input_zp: input zeropoint
///   weight_zp: weight zeropoint.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                   Value input, Value weight) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto weightType = weight.getType().dyn_cast<RankedTensorType>();

  if (!inputType || !weightType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);
  auto weightPerTensorQType = GET_UQTYPE(weightType);
  auto weightPerAxisQType = weightType.getElementType()
                                .dyn_cast<quant::UniformQuantizedPerAxisType>();

  // Weights must be either per-tensor quantized or per-axis quantized.
  assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
         "Weights must be either per-tensor or per-axis quantized");

  // Either all quantized or all not quantized.
  assert(!((bool)inputQType ^
           ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
         "Inputs and weights must be all quantized or all not quantized");

  if (inputQType) {

    int64_t inputZp = inputQType.getZeroPoint();
    int64_t weightZp = 0;

    if (weightPerTensorQType) {
      weightZp = weightPerTensorQType.getZeroPoint();
    } else if (weightPerAxisQType) {
      weightZp = weightPerAxisQType.getZeroPoints().front();
    }

    auto quantAttr = tosa::ConvOpQuantizationAttr::get(
        builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp),
        builder.getContext());

    return quantAttr;
  }

  return nullptr;
}
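
// Illustrative sketch (assumption about the calling convention, not taken
// from this file): a quant-info builder would typically attach the attribute
// returned above only when it is non-null, e.g.
//
//   if (auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight))
//     result.addAttribute("quantization_info", quantAttr);
//
// where `result` is the OperationState being assembled and the attribute name
// is illustrative.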

/// Builds MatMulOpQuantizationAttr, called from
/// MatMulOpQuantInfoBuilder:
///   aZp: input a zeropoint
///   bZp: input b zeropoint.
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
                                                       Value a, Value b) {

  auto aType = a.getType().dyn_cast<RankedTensorType>();
  auto bType = b.getType().dyn_cast<RankedTensorType>();

  if (!aType || !bType)
    return nullptr;

  auto aQType = GET_UQTYPE(aType);
  auto bQType = GET_UQTYPE(bType);

  // A and B are either all quantized or all not quantized.
  assert(!((bool)aQType ^ (bool)bQType) &&
         "Matmul operands must be all quantized or all not quantized");

  if (aQType) {

    int64_t aZp = aQType.getZeroPoint();
    int64_t bZp = bQType.getZeroPoint();

    auto quantAttr = tosa::MatMulOpQuantizationAttr::get(
        builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp),
        builder.getContext());

    return quantAttr;
  }

  return nullptr;
}

/// Builds UnaryOpQuantizationAttr, called from
/// UnaryOpQuantInfoBuilder:
///   inputZp: input zeropoint
///   outputZp: output zeropoint.
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder,
                                                     Value input,
                                                     Type outputRawType) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto outputType = outputRawType.dyn_cast<RankedTensorType>();

  if (!inputType || !outputType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);
  auto outputQType = GET_UQTYPE(outputType);

  // Either all quantized or all not quantized.
  assert(!((bool)inputQType ^ (bool)outputQType) &&
         "Unary inputs/outputs must be all quantized or all not quantized");

  if (inputQType) {

    int64_t inputZp = inputQType.getZeroPoint();
    int64_t outputZp = outputQType.getZeroPoint();

    auto quantAttr = tosa::UnaryOpQuantizationAttr::get(
        builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp),
        builder.getContext());

    return quantAttr;
  }

  return nullptr;
}

/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
///   inputZp: input zeropoint.
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder,
                                                 Value input) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();

  if (!inputType)
    return nullptr;

  auto inputQType = GET_UQTYPE(inputType);

  if (inputQType) {

    int64_t inputZp = inputQType.getZeroPoint();

    auto quantAttr = tosa::PadOpQuantizationAttr::get(
        builder.getI32IntegerAttr(inputZp), builder.getContext());

    return quantAttr;
  }

  return nullptr;
}

/// Builds the output type for a quantized ConvOp with the right accumulator
/// bitwidth. This is called by the builder when dealing with quantized
/// content.
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input,
                               Value weight) {

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto weightType = weight.getType().dyn_cast<RankedTensorType>();

  assert(inputType && weightType &&
         "Could not extract input or weight tensors from Conv op");

  auto inputQType = GET_QTYPE(inputType);
  auto weightQType = GET_QTYPE(weightType);

  assert(inputQType && weightQType &&
         "Could not extract input or weight tensor types from Conv op");

  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
  unsigned weightBits = weightQType.getStorageTypeIntegralWidth();

  auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
  assert(outputShapedType &&
         "Could not extract output shape type from Conv op");

  auto outputShape = outputShapedType.getShape();

  IntegerType accElementType;
  if (inputBits == 16 && weightBits == 8)
    accElementType = builder.getIntegerType(48);
  else
    accElementType = builder.getI32Type();
  auto accType = RankedTensorType::get(outputShape, accElementType);
  return accType;
}
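
// Illustrative sketch (editor's example): for i8 input and i8 weight storage
// the helper above yields an i32 accumulator tensor with the output's shape,
// while i16 input with i8 weights widens the accumulator to i48:
//
//   Type accType = buildConvOpResultTypeInfo(builder, outputType, input, weight);
//   // i8  x i8 -> tensor<...xi32>
//   // i16 x i8 -> tensor<...xi48>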

/// Builds a TOSA quantized type from min/max values.
Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr,
                          Attribute maxAttr, IntegerAttr quantBits,
                          int filterQuantDim, bool isSigned,
                          BoolAttr narrowRange) {

  quant::QuantizedType retType;

  auto convfunc =
      quant::ExpressedToQuantizedConverter::forInputType(inputDType);

  auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
  auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();

  SmallVector<double, 2> min, max;

  // At least one is a per-axis quantized elements attr.
  if (minElems || maxElems) {
    // Both must be elements attrs with the same number of elements.
    if (!minElems || !maxElems ||
        minElems.getNumElements() != maxElems.getNumElements())
      return {};
    min.reserve(minElems.getNumElements());
    max.reserve(maxElems.getNumElements());
    for (auto i : minElems)
      min.push_back(FloatAttr::getValueAsDouble(i));
    for (auto i : maxElems)
      max.push_back(FloatAttr::getValueAsDouble(i));
  } else { // Just a single FP value.
    auto minVal = minAttr.dyn_cast<FloatAttr>();
    if (minVal)
      min.push_back(minVal.getValueAsDouble());
    else
      return {};
    auto maxVal = maxAttr.dyn_cast<FloatAttr>();
    if (maxVal)
      max.push_back(maxVal.getValueAsDouble());
    else
      return {};
  }

  if (min.size() == max.size()) {
    if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
      retType = quant::fakeQuantAttrsToType(
          builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
          narrowRange.getValue(), convfunc.expressedType, isSigned);
    } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
      auto shape = inputDType.dyn_cast<ShapedType>();
      if (!shape)
        return {};
      if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
        retType = quant::fakeQuantAttrsToType(
            builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
            max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
      }
    } else {
      return {};
    }
  } else {
    return {};
  }

  if (!retType)
    return {};

  return convfunc.convert(retType);
}

/// Builds a TOSA quantized type attribute from min/max values.
TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
                                  Attribute minAttr, Attribute maxAttr,
                                  IntegerAttr quantBits, int filterQuantDim,
                                  bool isSigned, BoolAttr narrowRange) {

  return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
                                            maxAttr, quantBits, filterQuantDim,
                                            isSigned, narrowRange));
}
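
// Illustrative sketch (editor's example with assumed scalar attributes): a
// per-tensor 8-bit signed quantized type can be derived from a single min/max
// pair roughly as follows; the float values are hypothetical.
//
//   TypeAttr qTypeAttr = buildQTypeAttrFromMinMax(
//       builder, builder.getF32Type(), builder.getF32FloatAttr(-1.0f),
//       builder.getF32FloatAttr(1.0f), builder.getI32IntegerAttr(8),
//       /*filterQuantDim=*/-1, /*isSigned=*/true, builder.getBoolAttr(false));
//   // qTypeAttr wraps a quant::UniformQuantizedType expressed in f32.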