xref: /llvm-project/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (revision 3745e7080746b73377a479b6ceba2dbf25f245e2)
1 //===- QuantUtils.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains TOSA numerical support functions and quantization
10 // attribute builders.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
/// Decomposes `scale` into a fixed-point multiplier and a right-shift amount
/// for 16-bit TOSA scaling, such that scale ~= multiplier * 2^(-shift).
/// The multiplier is the std::frexp mantissa of `scale` rounded to 15
/// fractional bits. Shifts above 62 are clamped to 62, with the multiplier
/// right-shifted to compensate.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kFractionBits = 15;
  constexpr int32_t kMaxShift = 62;
  constexpr int64_t kFixedPointOne = int64_t(1) << kFractionBits;

  int exponent = 0;
  const double mantissa = std::frexp(scale, &exponent);
  double fixedPoint = std::round(mantissa * kFixedPointOne);

  // The mantissa magnitude is below 1.0, so rounding can reach 1.0 at most.
  assert(fixedPoint <= kFixedPointOne &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding produced exactly 1.0: renormalize to 0.5 and bump the exponent.
  if (fixedPoint == kFixedPointOne) {
    fixedPoint /= 2;
    ++exponent;
  }

  assert(fixedPoint <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  int32_t fixedMultiplier = static_cast<int32_t>(fixedPoint);

  // TOSA wants a positive right shift that also accounts for the
  // (1 << 15) scaling baked into the multiplier.
  int32_t rightShift = kFractionBits - exponent;

  // A shift of at most 62 can be decomposed into two right shifts of 31.
  // Move any excess onto the multiplier; shifting it by more than 31 bits
  // is unnecessary.
  if (rightShift > kMaxShift) {
    fixedMultiplier >>= std::min<int32_t>(31, rightShift - kMaxShift);
    rightShift = kMaxShift;
  }

  multiplier = fixedMultiplier;
  shift = rightShift;
}
56 
/// Decomposes `scale` into a fixed-point multiplier and a right-shift amount
/// for 32-bit TOSA scaling, such that scale ~= multiplier * 2^(-shift).
/// The multiplier is the std::frexp mantissa of `scale` rounded to 31
/// fractional bits. Shifts above 62 are clamped to 62, with the multiplier
/// right-shifted to compensate.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kFractionBits = 31;
  constexpr int32_t kMaxShift = 62;
  constexpr int64_t kFixedPointOne = int64_t(1) << kFractionBits;

  int exponent = 0;
  const double mantissa = std::frexp(scale, &exponent);
  double fixedPoint = std::round(mantissa * kFixedPointOne);

  // The mantissa magnitude is below 1.0, so rounding can reach 1.0 at most.
  assert(fixedPoint <= kFixedPointOne &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding produced exactly 1.0: renormalize to 0.5 and bump the exponent
  // so the value fits a signed 32-bit multiplier.
  if (fixedPoint == kFixedPointOne) {
    fixedPoint /= 2;
    ++exponent;
  }

  assert(fixedPoint <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  int32_t fixedMultiplier = static_cast<int32_t>(fixedPoint);

  // TOSA wants a positive right shift that also accounts for the
  // (1 << 31) scaling baked into the multiplier.
  int32_t rightShift = kFractionBits - exponent;

  // A shift of at most 62 can be decomposed into two right shifts of 31.
  // Move any excess onto the multiplier; shifting it by more than 31 bits
  // is unnecessary.
  if (rightShift > kMaxShift) {
    fixedMultiplier >>= std::min<int32_t>(31, rightShift - kMaxShift);
    rightShift = kMaxShift;
  }

  multiplier = fixedMultiplier;
  shift = rightShift;
}
93 
94 /// Generates a quantized multiplier/shift from double.
computeMultiplierAndShift(double scale,int32_t & multiplier,int32_t & shift,int32_t scaleWidth)95 void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
96                                            int32_t &shift, int32_t scaleWidth) {
97 
98   switch (scaleWidth) {
99   case 16:
100     computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
101     return;
102   case 32:
103     computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
104     return;
105   default:
106     assert(0 && "Unsupported Tosa quantized_scale regime specified!");
107   }
108 }
109 
// Element-type helpers. Each expands to an llvm::dyn_cast of the argument's
// getElementType() result to the named quantized type.
#define GET_UQTYPE(shapedTy)                                                   \
  (llvm::dyn_cast<quant::UniformQuantizedType>((shapedTy).getElementType()))
#define GET_QTYPE(shapedTy)                                                    \
  (llvm::dyn_cast<quant::QuantizedType>((shapedTy).getElementType()))
114 
115 /// Method to build ConvOpQuantizationAttr, called from
116 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
117 /// input_zp: input zeropoint
118 /// weight_zp: weight zeropoint.
119 ConvOpQuantizationAttr
buildConvOpQuantizationAttr(OpBuilder & builder,Value input,Value weight)120 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
121                                         Value weight) {
122 
123   auto inputType = dyn_cast<ShapedType>(input.getType());
124   auto weightType = dyn_cast<ShapedType>(weight.getType());
125 
126   if (!inputType || !weightType)
127     return nullptr;
128 
129   auto inputQType = GET_UQTYPE(inputType);
130   auto weightPerTensorQType = GET_UQTYPE(weightType);
131   auto weightPerAxisQType =
132       dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
133 
134   // Weights must be either per-tensor quantized or per-axis quantized.
135   assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
136          "Weights must be either per-tensor or per-axis quantized");
137 
138   // Either all quantized or all not quantized.
139   assert(!((bool)inputQType ^
140            ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
141          "Inputs and weights must be all quantized or all not quantized");
142 
143   if (inputQType) {
144     int64_t inputZp = inputQType.getZeroPoint();
145     int64_t weightZp = 0;
146 
147     if (weightPerTensorQType) {
148       weightZp = weightPerTensorQType.getZeroPoint();
149     } else if (weightPerAxisQType) {
150       weightZp = weightPerAxisQType.getZeroPoints().front();
151     }
152 
153     return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
154   }
155 
156   return nullptr;
157 }
158 
159 /// Builds MatMulOpQuantizationAttr, called from
160 /// MatMulOpQuantInfoBuilder:
161 /// aZp: input a zeropoint
162 /// bZp: input b zeropoint.
163 MatMulOpQuantizationAttr
buildMatMulOpQuantizationAttr(OpBuilder & builder,Value a,Value b)164 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
165                                           Value b) {
166 
167   auto aType = dyn_cast<ShapedType>(a.getType());
168   auto bType = dyn_cast<ShapedType>(b.getType());
169 
170   if (!aType || !bType)
171     return nullptr;
172 
173   auto aQType = GET_UQTYPE(aType);
174   auto bQType = GET_UQTYPE(bType);
175 
176   // A and B are either all quantized or all not quantized.
177   assert(!((bool)aQType ^ (bool)bQType) &&
178          "Matmul operands must be all quantized or all not quantized");
179 
180   if (aQType) {
181     return builder.getAttr<tosa::MatMulOpQuantizationAttr>(
182         aQType.getZeroPoint(), bQType.getZeroPoint());
183   }
184 
185   return nullptr;
186 }
187 
188 /// Builds UnaryOpQuantizationAttr
189 /// UnaryOpQuantInfoBuilder:
190 /// inputZp: input zeropoint
191 /// outputZp: output zeropoint.
192 UnaryOpQuantizationAttr
buildUnaryOpQuantizationAttr(OpBuilder & builder,Value input,Type outputRawType)193 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
194                                          Type outputRawType) {
195 
196   auto inputType = dyn_cast<ShapedType>(input.getType());
197   auto outputType = dyn_cast<ShapedType>(outputRawType);
198 
199   if (!inputType || !outputType)
200     return nullptr;
201 
202   auto inputQType = GET_UQTYPE(inputType);
203   auto outputQType = GET_UQTYPE(outputType);
204 
205   // Either all quantized or all not quantized.
206   assert(!((bool)inputQType ^ (bool)outputQType) &&
207          "Unary inputs/outputs must be all quantized or all not quantized");
208 
209   if (inputQType) {
210     return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
211                                                     outputQType.getZeroPoint());
212   }
213 
214   return nullptr;
215 }
216 
217 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
218 /// inputZp: input zeropoint.
buildPadOpQuantizationAttr(OpBuilder & builder,Value input)219 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
220                                                              Value input) {
221 
222   auto inputType = dyn_cast<ShapedType>(input.getType());
223 
224   if (!inputType)
225     return nullptr;
226 
227   auto inputQType = GET_UQTYPE(inputType);
228 
229   if (inputQType) {
230     return builder.getAttr<tosa::PadOpQuantizationAttr>(
231         inputQType.getZeroPoint());
232   }
233 
234   return nullptr;
235 }
236 
237 /// Builds output type for a quantized ConvOp with the right bitwidth.
238 /// This is called by the builder when dealing with quantized content.
buildConvOpResultTypeInfo(OpBuilder & builder,Type outputType,Value input,Value weight)239 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
240                                            Value input, Value weight) {
241 
242   auto inputType = dyn_cast<ShapedType>(input.getType());
243   auto weightType = dyn_cast<ShapedType>(weight.getType());
244 
245   assert(inputType && weightType &&
246          "Could not extract input or weight tensors from Conv op");
247 
248   auto inputQType = GET_QTYPE(inputType);
249   auto weightQType = GET_QTYPE(weightType);
250 
251   assert(inputQType && weightQType &&
252          "Could not extract input or weight tensor types from Conv op");
253 
254   unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
255   unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
256 
257   auto outputShapedType = dyn_cast<ShapedType>(outputType);
258   assert(outputShapedType &&
259          "Could not extract output shape type from Conv op");
260 
261   IntegerType accElementType;
262   if (inputBits == 16 && weightBits == 8)
263     accElementType = builder.getIntegerType(48);
264   else
265     accElementType = builder.getI32Type();
266   auto accType = outputShapedType.clone(accElementType);
267   return accType;
268 }
269 
270 /// Builds Tosa quantization attributes from min/max values.
buildQTypeFromMinMax(OpBuilder builder,Type inputDType,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)271 Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
272                                       Attribute minAttr, Attribute maxAttr,
273                                       IntegerAttr quantBits, int filterQuantDim,
274                                       bool isSigned, BoolAttr narrowRange) {
275 
276   quant::QuantizedType retType;
277 
278   auto convfunc =
279       quant::ExpressedToQuantizedConverter::forInputType(inputDType);
280 
281   auto minElems = dyn_cast<DenseFPElementsAttr>(minAttr);
282   auto maxElems = dyn_cast<DenseFPElementsAttr>(maxAttr);
283 
284   SmallVector<double, 2> min, max;
285 
286   // At least one is per-axis quantized elementsattr.
287   if (minElems || maxElems) {
288     // Must have the same number of elements.
289     if (minElems.getNumElements() != maxElems.getNumElements())
290       return {};
291     min.reserve(minElems.getNumElements());
292     max.reserve(maxElems.getNumElements());
293     for (auto i : minElems)
294       min.push_back(FloatAttr::getValueAsDouble(i));
295     for (auto i : maxElems)
296       max.push_back(FloatAttr::getValueAsDouble(i));
297   } else { // Just a single FP value.
298     auto minVal = dyn_cast<FloatAttr>(minAttr);
299     if (minVal)
300       min.push_back(minVal.getValueAsDouble());
301     else
302       return {};
303     auto maxVal = dyn_cast<FloatAttr>(maxAttr);
304     if (maxVal)
305       max.push_back(maxVal.getValueAsDouble());
306     else
307       return {};
308   }
309 
310   if (min.size() == max.size()) {
311     if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
312       retType = quant::fakeQuantAttrsToType(
313           builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
314           narrowRange.getValue(), convfunc.expressedType, isSigned);
315     } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
316       auto shape = dyn_cast<ShapedType>(inputDType);
317       if (!shape)
318         return {};
319       if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
320         retType = quant::fakeQuantAttrsToType(
321             builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
322             max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
323       }
324     } else {
325       return {};
326     }
327   } else {
328     return {};
329   }
330 
331   if (!retType)
332     return {};
333 
334   return convfunc.convert(retType);
335 }
336 
337 /// Builds Tosa quantization attributes from min/max values.
338 TypeAttr
buildQTypeAttrFromMinMax(OpBuilder builder,Type inputDtype,Attribute minAttr,Attribute maxAttr,IntegerAttr quantBits,int filterQuantDim,bool isSigned,BoolAttr narrowRange)339 mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
340                                      Attribute minAttr, Attribute maxAttr,
341                                      IntegerAttr quantBits, int filterQuantDim,
342                                      bool isSigned, BoolAttr narrowRange) {
343 
344   return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
345                                             maxAttr, quantBits, filterQuantDim,
346                                             isSigned, narrowRange));
347 }
348