xref: /llvm-project/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (revision b28121133d8cc4c57cc086b94e1248e7a2555465)
1 //===- QuantUtils.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains TOSA numerical support functions and quantization
10 // attribute builders.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
/// Decomposes `scale` into a TOSA 16-bit multiplier / right-shift pair such
/// that scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits `scale` into a mantissa whose magnitude lies in
/// [0.5, 1.0) and a binary exponent. The mantissa is scaled by 2^15 and
/// rounded to form the multiplier; the exponent is negated and biased by 15
/// to form the right shift.
///
/// \param scale       floating-point scale to decompose.
/// \param multiplier  receives the integer multiplier.
/// \param shift       receives the right-shift amount, never larger than 62.
void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier,
                                          int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Rounding may lift the scaled mantissa to exactly 2^15, never beyond it.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  // Renormalize when rounding reached 2^15: halve the mantissa and
  // compensate in the exponent.
  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Very small scales produce a shift beyond 62, the largest right shift the
  // TOSA scaling operation accepts. Move the excess into the multiplier
  // instead. Shifting a 32-bit multiplier by more than 31 bits is pointless,
  // so the excess is capped there.
  if (shift > 62) {
    const int32_t excess = shift - 62;
    multiplier >>= (excess < 31 ? excess : 31);
    shift = 62;
  }
}
46 
/// Decomposes `scale` into a TOSA 32-bit multiplier / right-shift pair such
/// that scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits `scale` into a mantissa whose magnitude lies in
/// [0.5, 1.0) and a binary exponent. The mantissa is scaled by 2^31 and
/// rounded to form the multiplier; the exponent is negated and biased by 31
/// to form the right shift.
///
/// \param scale       floating-point scale to decompose.
/// \param multiplier  receives the integer multiplier.
/// \param shift       receives the right-shift amount, never larger than 62.
void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier,
                                          int32_t &shift) {

  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Rounding may lift the scaled mantissa to exactly 2^31, never beyond it.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");

  // Renormalize when rounding reached 2^31 (which would not fit in int32_t):
  // halve the mantissa and compensate in the exponent.
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Very small scales produce a shift beyond 62, the largest right shift the
  // TOSA scaling operation accepts. Move the excess into the multiplier
  // instead. Shifting a 32-bit multiplier by more than 31 bits is pointless,
  // so the excess is capped there.
  if (shift > 62) {
    const int32_t excess = shift - 62;
    multiplier >>= (excess < 31 ? excess : 31);
    shift = 62;
  }
}
73 
74 /// Generates a quantized multiplier/shift from double.
75 void computeMultiplierAndShift(double scale, int32_t &multiplier,
76                                int32_t &shift, int32_t scaleWidth) {
77 
78   switch (scaleWidth) {
79   case 16:
80     computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
81     return;
82   case 32:
83     computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
84     return;
85   default:
86     assert(0 && "Unsupported Tosa quantized_scale regime specified!");
87   }
88 }
89 
/// Casts the element type of `input_type` to quant::UniformQuantizedType.
/// The callers in this file test the result for null, which is what the
/// dyn_cast yields when the element type is of a different kind.
#define GET_UQTYPE(input_type)                                                 \
  ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
/// Same as GET_UQTYPE, but casts to the more general quant::QuantizedType.
#define GET_QTYPE(input_type)                                                  \
  ((input_type).getElementType().dyn_cast<quant::QuantizedType>())
94 
95 /// Method to build ConvOpQuantizationAttr, called from
96 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
97 /// input_zp: input zeropoint
98 /// weight_zp: weight zeropoint.
99 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
100                                                    Value input, Value weight) {
101 
102   auto inputType = input.getType().dyn_cast<RankedTensorType>();
103   auto weightType = weight.getType().dyn_cast<RankedTensorType>();
104 
105   if (!inputType || !weightType)
106     return nullptr;
107 
108   auto inputQType = GET_UQTYPE(inputType);
109   auto weightPerTensorQType = GET_UQTYPE(weightType);
110   auto weightPerAxisQType = weightType.getElementType()
111                                 .dyn_cast<quant::UniformQuantizedPerAxisType>();
112 
113   // Weights must be either per-tensor quantized or per-axis quantized.
114   assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
115          "Weights must be either per-tensor or per-axis quantized");
116 
117   // Either all quantized or all not quantized.
118   assert(!((bool)inputQType ^
119            ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
120          "Inputs and weights must be all quantized or all not quantized");
121 
122   if (inputQType) {
123 
124     int64_t inputZp = inputQType.getZeroPoint();
125     int64_t weightZp = 0;
126 
127     if (weightPerTensorQType) {
128       weightZp = weightPerTensorQType.getZeroPoint();
129     } else if (weightPerAxisQType) {
130       weightZp = weightPerAxisQType.getZeroPoints().front();
131     }
132 
133     auto quantAttr = tosa::ConvOpQuantizationAttr::get(
134         builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp),
135         builder.getContext());
136 
137     return quantAttr;
138   }
139 
140   return nullptr;
141 }
142 
143 /// Builds MatMulOpQuantizationAttr, called from
144 /// MatMulOpQuantInfoBuilder:
145 /// aZp: input a zeropoint
146 /// bZp: input b zeropoint.
147 MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
148                                                        Value a, Value b) {
149 
150   auto aType = a.getType().dyn_cast<RankedTensorType>();
151   auto bType = b.getType().dyn_cast<RankedTensorType>();
152 
153   if (!aType || !bType)
154     return nullptr;
155 
156   auto aQType = GET_UQTYPE(aType);
157   auto bQType = GET_UQTYPE(bType);
158 
159   // A and B are either all quantized or all not quantized.
160   assert(!((bool)aQType ^ (bool)bQType) &&
161          "Matmul operands must be all quantized or all not quantized");
162 
163   if (aQType) {
164 
165     int64_t aZp = aQType.getZeroPoint();
166     int64_t bZp = bQType.getZeroPoint();
167 
168     auto quantAttr = tosa::MatMulOpQuantizationAttr::get(
169         builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp),
170         builder.getContext());
171 
172     return quantAttr;
173   }
174 
175   return nullptr;
176 }
177 
178 /// Builds UnaryOpQuantizationAttr
179 /// UnaryOpQuantInfoBuilder:
180 /// inputZp: input zeropoint
181 /// outputZp: output zeropoint.
182 UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder,
183                                                      Value input,
184                                                      Type outputRawType) {
185 
186   auto inputType = input.getType().dyn_cast<RankedTensorType>();
187   auto outputType = outputRawType.dyn_cast<RankedTensorType>();
188 
189   if (!inputType || !outputType)
190     return nullptr;
191 
192   auto inputQType = GET_UQTYPE(inputType);
193   auto outputQType = GET_UQTYPE(outputType);
194 
195   // Either all quantized or all not quantized.
196   assert(!((bool)inputQType ^ (bool)outputQType) &&
197          "Unary inputs/outputs must be all quantized or all not quantized");
198 
199   if (inputQType) {
200 
201     int64_t inputZp = inputQType.getZeroPoint();
202     int64_t outputZp = outputQType.getZeroPoint();
203 
204     auto quantAttr = tosa::UnaryOpQuantizationAttr::get(
205         builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp),
206         builder.getContext());
207 
208     return quantAttr;
209   }
210 
211   return nullptr;
212 }
213 
214 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
215 /// inputZp: input zeropoint.
216 PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder,
217                                                  Value input) {
218 
219   auto inputType = input.getType().dyn_cast<RankedTensorType>();
220 
221   if (!inputType)
222     return nullptr;
223 
224   auto inputQType = GET_UQTYPE(inputType);
225 
226   if (inputQType) {
227 
228     int64_t inputZp = inputQType.getZeroPoint();
229 
230     auto quantAttr = tosa::PadOpQuantizationAttr::get(
231         builder.getI32IntegerAttr(inputZp), builder.getContext());
232 
233     return quantAttr;
234   }
235 
236   return nullptr;
237 }
238 
239 /// Builds output type for a quantized ConvOp with the right bitwidth.
240 /// This is called by the builder when dealing with quantized content.
241 Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input,
242                                Value weight) {
243 
244   auto inputType = input.getType().dyn_cast<RankedTensorType>();
245   auto weightType = weight.getType().dyn_cast<RankedTensorType>();
246 
247   assert(inputType && weightType &&
248          "Could not extract input or weight tensors from Conv op");
249 
250   auto inputQType = GET_QTYPE(inputType);
251   auto weightQType = GET_QTYPE(weightType);
252 
253   assert(inputQType && weightQType &&
254          "Could not extract input or weight tensor types from Conv op");
255 
256   unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
257   unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
258 
259   auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
260   assert(outputShapedType &&
261          "Could not extract output shape type from Conv op");
262 
263   auto outputShape = outputShapedType.getShape();
264 
265   IntegerType accElementType;
266   if (inputBits == 16 && weightBits == 8)
267     accElementType = builder.getIntegerType(48);
268   else
269     accElementType = builder.getI32Type();
270   auto accType = RankedTensorType::get(outputShape, accElementType);
271   return accType;
272 }
273 
274 /// Builds Tosa quantization attributes from min/max values.
275 Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr,
276                           Attribute maxAttr, IntegerAttr quantBits,
277                           int filterQuantDim, bool isSigned,
278                           BoolAttr narrowRange) {
279 
280   quant::QuantizedType retType;
281 
282   auto convfunc =
283       quant::ExpressedToQuantizedConverter::forInputType(inputDType);
284 
285   auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
286   auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
287 
288   SmallVector<double, 2> min, max;
289 
290   // At least one is per-axis quantized elementsattr.
291   if (minElems || maxElems) {
292     // Must have the same number of elements.
293     if (minElems.getNumElements() != maxElems.getNumElements())
294       return {};
295     min.reserve(minElems.getNumElements());
296     max.reserve(maxElems.getNumElements());
297     for (auto i : minElems)
298       min.push_back(FloatAttr::getValueAsDouble(i));
299     for (auto i : maxElems)
300       max.push_back(FloatAttr::getValueAsDouble(i));
301   } else { // Just a single FP value.
302     auto minVal = minAttr.dyn_cast<FloatAttr>();
303     if (minVal)
304       min.push_back(minVal.getValueAsDouble());
305     else
306       return {};
307     auto maxVal = maxAttr.dyn_cast<FloatAttr>();
308     if (maxVal)
309       max.push_back(maxVal.getValueAsDouble());
310     else
311       return {};
312   }
313 
314   if (min.size() == max.size()) {
315     if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
316       retType = quant::fakeQuantAttrsToType(
317           builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
318           narrowRange.getValue(), convfunc.expressedType, isSigned);
319     } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
320       auto shape = inputDType.dyn_cast<ShapedType>();
321       if (!shape)
322         return {};
323       if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
324         retType = quant::fakeQuantAttrsToType(
325             builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
326             max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
327       }
328     } else {
329       return {};
330     }
331   } else {
332     return {};
333   }
334 
335   if (!retType)
336     return {};
337 
338   return convfunc.convert(retType);
339 }
340 
341 /// Builds Tosa quantization attributes from min/max values.
342 TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
343                                   Attribute minAttr, Attribute maxAttr,
344                                   IntegerAttr quantBits, int filterQuantDim,
345                                   bool isSigned, BoolAttr narrowRange) {
346 
347   return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
348                                             maxAttr, quantBits, filterQuantDim,
349                                             isSigned, narrowRange));
350 }
351