xref: /llvm-project/llvm/lib/Support/APFixedPoint.cpp (revision 16544cbe64b81a50800a88296ef37f4873a37b25)
1 //===- APFixedPoint.cpp - Fixed point constant handling ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Defines the implementation for the fixed point number interface.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/APFixedPoint.h"
15 #include "llvm/ADT/APFloat.h"
16 
17 #include <cmath>
18 
19 namespace llvm {
20 
21 APFixedPoint APFixedPoint::convert(const FixedPointSemantics &DstSema,
22                                    bool *Overflow) const {
23   APSInt NewVal = Val;
24   unsigned DstWidth = DstSema.getWidth();
25   unsigned DstScale = DstSema.getScale();
26   bool Upscaling = DstScale > getScale();
27   if (Overflow)
28     *Overflow = false;
29 
30   if (Upscaling) {
31     NewVal = NewVal.extend(NewVal.getBitWidth() + DstScale - getScale());
32     NewVal <<= (DstScale - getScale());
33   } else {
34     NewVal >>= (getScale() - DstScale);
35   }
36 
37   auto Mask = APInt::getBitsSetFrom(
38       NewVal.getBitWidth(),
39       std::min(DstScale + DstSema.getIntegralBits(), NewVal.getBitWidth()));
40   APInt Masked(NewVal & Mask);
41 
42   // Change in the bits above the sign
43   if (!(Masked == Mask || Masked == 0)) {
44     // Found overflow in the bits above the sign
45     if (DstSema.isSaturated())
46       NewVal = NewVal.isNegative() ? Mask : ~Mask;
47     else if (Overflow)
48       *Overflow = true;
49   }
50 
51   // If the dst semantics are unsigned, but our value is signed and negative, we
52   // clamp to zero.
53   if (!DstSema.isSigned() && NewVal.isSigned() && NewVal.isNegative()) {
54     // Found negative overflow for unsigned result
55     if (DstSema.isSaturated())
56       NewVal = 0;
57     else if (Overflow)
58       *Overflow = true;
59   }
60 
61   NewVal = NewVal.extOrTrunc(DstWidth);
62   NewVal.setIsSigned(DstSema.isSigned());
63   return APFixedPoint(NewVal, DstSema);
64 }
65 
66 int APFixedPoint::compare(const APFixedPoint &Other) const {
67   APSInt ThisVal = getValue();
68   APSInt OtherVal = Other.getValue();
69   bool ThisSigned = Val.isSigned();
70   bool OtherSigned = OtherVal.isSigned();
71   unsigned OtherScale = Other.getScale();
72   unsigned OtherWidth = OtherVal.getBitWidth();
73 
74   unsigned CommonWidth = std::max(Val.getBitWidth(), OtherWidth);
75 
76   // Prevent overflow in the event the widths are the same but the scales differ
77   CommonWidth += getScale() >= OtherScale ? getScale() - OtherScale
78                                           : OtherScale - getScale();
79 
80   ThisVal = ThisVal.extOrTrunc(CommonWidth);
81   OtherVal = OtherVal.extOrTrunc(CommonWidth);
82 
83   unsigned CommonScale = std::max(getScale(), OtherScale);
84   ThisVal = ThisVal.shl(CommonScale - getScale());
85   OtherVal = OtherVal.shl(CommonScale - OtherScale);
86 
87   if (ThisSigned && OtherSigned) {
88     if (ThisVal.sgt(OtherVal))
89       return 1;
90     else if (ThisVal.slt(OtherVal))
91       return -1;
92   } else if (!ThisSigned && !OtherSigned) {
93     if (ThisVal.ugt(OtherVal))
94       return 1;
95     else if (ThisVal.ult(OtherVal))
96       return -1;
97   } else if (ThisSigned && !OtherSigned) {
98     if (ThisVal.isSignBitSet())
99       return -1;
100     else if (ThisVal.ugt(OtherVal))
101       return 1;
102     else if (ThisVal.ult(OtherVal))
103       return -1;
104   } else {
105     // !ThisSigned && OtherSigned
106     if (OtherVal.isSignBitSet())
107       return 1;
108     else if (ThisVal.ugt(OtherVal))
109       return 1;
110     else if (ThisVal.ult(OtherVal))
111       return -1;
112   }
113 
114   return 0;
115 }
116 
117 APFixedPoint APFixedPoint::getMax(const FixedPointSemantics &Sema) {
118   bool IsUnsigned = !Sema.isSigned();
119   auto Val = APSInt::getMaxValue(Sema.getWidth(), IsUnsigned);
120   if (IsUnsigned && Sema.hasUnsignedPadding())
121     Val = Val.lshr(1);
122   return APFixedPoint(Val, Sema);
123 }
124 
125 APFixedPoint APFixedPoint::getMin(const FixedPointSemantics &Sema) {
126   auto Val = APSInt::getMinValue(Sema.getWidth(), !Sema.isSigned());
127   return APFixedPoint(Val, Sema);
128 }
129 
130 bool FixedPointSemantics::fitsInFloatSemantics(
131     const fltSemantics &FloatSema) const {
132   // A fixed point semantic fits in a floating point semantic if the maximum
133   // and minimum values as integers of the fixed point semantic can fit in the
134   // floating point semantic.
135 
136   // If these values do not fit, then a floating point rescaling of the true
137   // maximum/minimum value will not fit either, so the floating point semantic
138   // cannot be used to perform such a rescaling.
139 
140   APSInt MaxInt = APFixedPoint::getMax(*this).getValue();
141   APFloat F(FloatSema);
142   APFloat::opStatus Status = F.convertFromAPInt(MaxInt, MaxInt.isSigned(),
143                                                 APFloat::rmNearestTiesToAway);
144   if ((Status & APFloat::opOverflow) || !isSigned())
145     return !(Status & APFloat::opOverflow);
146 
147   APSInt MinInt = APFixedPoint::getMin(*this).getValue();
148   Status = F.convertFromAPInt(MinInt, MinInt.isSigned(),
149                               APFloat::rmNearestTiesToAway);
150   return !(Status & APFloat::opOverflow);
151 }
152 
153 FixedPointSemantics FixedPointSemantics::getCommonSemantics(
154     const FixedPointSemantics &Other) const {
155   unsigned CommonScale = std::max(getScale(), Other.getScale());
156   unsigned CommonWidth =
157       std::max(getIntegralBits(), Other.getIntegralBits()) + CommonScale;
158 
159   bool ResultIsSigned = isSigned() || Other.isSigned();
160   bool ResultIsSaturated = isSaturated() || Other.isSaturated();
161   bool ResultHasUnsignedPadding = false;
162   if (!ResultIsSigned) {
163     // Both are unsigned.
164     ResultHasUnsignedPadding = hasUnsignedPadding() &&
165                                Other.hasUnsignedPadding() && !ResultIsSaturated;
166   }
167 
168   // If the result is signed, add an extra bit for the sign. Otherwise, if it is
169   // unsigned and has unsigned padding, we only need to add the extra padding
170   // bit back if we are not saturating.
171   if (ResultIsSigned || ResultHasUnsignedPadding)
172     CommonWidth++;
173 
174   return FixedPointSemantics(CommonWidth, CommonScale, ResultIsSigned,
175                              ResultIsSaturated, ResultHasUnsignedPadding);
176 }
177 
178 APFixedPoint APFixedPoint::add(const APFixedPoint &Other,
179                                bool *Overflow) const {
180   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
181   APFixedPoint ConvertedThis = convert(CommonFXSema);
182   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
183   APSInt ThisVal = ConvertedThis.getValue();
184   APSInt OtherVal = ConvertedOther.getValue();
185   bool Overflowed = false;
186 
187   APSInt Result;
188   if (CommonFXSema.isSaturated()) {
189     Result = CommonFXSema.isSigned() ? ThisVal.sadd_sat(OtherVal)
190                                      : ThisVal.uadd_sat(OtherVal);
191   } else {
192     Result = ThisVal.isSigned() ? ThisVal.sadd_ov(OtherVal, Overflowed)
193                                 : ThisVal.uadd_ov(OtherVal, Overflowed);
194   }
195 
196   if (Overflow)
197     *Overflow = Overflowed;
198 
199   return APFixedPoint(Result, CommonFXSema);
200 }
201 
202 APFixedPoint APFixedPoint::sub(const APFixedPoint &Other,
203                                bool *Overflow) const {
204   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
205   APFixedPoint ConvertedThis = convert(CommonFXSema);
206   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
207   APSInt ThisVal = ConvertedThis.getValue();
208   APSInt OtherVal = ConvertedOther.getValue();
209   bool Overflowed = false;
210 
211   APSInt Result;
212   if (CommonFXSema.isSaturated()) {
213     Result = CommonFXSema.isSigned() ? ThisVal.ssub_sat(OtherVal)
214                                      : ThisVal.usub_sat(OtherVal);
215   } else {
216     Result = ThisVal.isSigned() ? ThisVal.ssub_ov(OtherVal, Overflowed)
217                                 : ThisVal.usub_ov(OtherVal, Overflowed);
218   }
219 
220   if (Overflow)
221     *Overflow = Overflowed;
222 
223   return APFixedPoint(Result, CommonFXSema);
224 }
225 
226 APFixedPoint APFixedPoint::mul(const APFixedPoint &Other,
227                                bool *Overflow) const {
228   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
229   APFixedPoint ConvertedThis = convert(CommonFXSema);
230   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
231   APSInt ThisVal = ConvertedThis.getValue();
232   APSInt OtherVal = ConvertedOther.getValue();
233   bool Overflowed = false;
234 
235   // Widen the LHS and RHS so we can perform a full multiplication.
236   unsigned Wide = CommonFXSema.getWidth() * 2;
237   if (CommonFXSema.isSigned()) {
238     ThisVal = ThisVal.sext(Wide);
239     OtherVal = OtherVal.sext(Wide);
240   } else {
241     ThisVal = ThisVal.zext(Wide);
242     OtherVal = OtherVal.zext(Wide);
243   }
244 
245   // Perform the full multiplication and downscale to get the same scale.
246   //
247   // Note that the right shifts here perform an implicit downwards rounding.
248   // This rounding could discard bits that would technically place the result
249   // outside the representable range. We interpret the spec as allowing us to
250   // perform the rounding step first, avoiding the overflow case that would
251   // arise.
252   APSInt Result;
253   if (CommonFXSema.isSigned())
254     Result = ThisVal.smul_ov(OtherVal, Overflowed)
255                     .ashr(CommonFXSema.getScale());
256   else
257     Result = ThisVal.umul_ov(OtherVal, Overflowed)
258                     .lshr(CommonFXSema.getScale());
259   assert(!Overflowed && "Full multiplication cannot overflow!");
260   Result.setIsSigned(CommonFXSema.isSigned());
261 
262   // If our result lies outside of the representative range of the common
263   // semantic, we either have overflow or saturation.
264   APSInt Max = APFixedPoint::getMax(CommonFXSema).getValue()
265                                                  .extOrTrunc(Wide);
266   APSInt Min = APFixedPoint::getMin(CommonFXSema).getValue()
267                                                  .extOrTrunc(Wide);
268   if (CommonFXSema.isSaturated()) {
269     if (Result < Min)
270       Result = Min;
271     else if (Result > Max)
272       Result = Max;
273   } else
274     Overflowed = Result < Min || Result > Max;
275 
276   if (Overflow)
277     *Overflow = Overflowed;
278 
279   return APFixedPoint(Result.sextOrTrunc(CommonFXSema.getWidth()),
280                       CommonFXSema);
281 }
282 
283 APFixedPoint APFixedPoint::div(const APFixedPoint &Other,
284                                bool *Overflow) const {
285   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
286   APFixedPoint ConvertedThis = convert(CommonFXSema);
287   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
288   APSInt ThisVal = ConvertedThis.getValue();
289   APSInt OtherVal = ConvertedOther.getValue();
290   bool Overflowed = false;
291 
292   // Widen the LHS and RHS so we can perform a full division.
293   unsigned Wide = CommonFXSema.getWidth() * 2;
294   if (CommonFXSema.isSigned()) {
295     ThisVal = ThisVal.sext(Wide);
296     OtherVal = OtherVal.sext(Wide);
297   } else {
298     ThisVal = ThisVal.zext(Wide);
299     OtherVal = OtherVal.zext(Wide);
300   }
301 
302   // Upscale to compensate for the loss of precision from division, and
303   // perform the full division.
304   ThisVal = ThisVal.shl(CommonFXSema.getScale());
305   APSInt Result;
306   if (CommonFXSema.isSigned()) {
307     APInt Rem;
308     APInt::sdivrem(ThisVal, OtherVal, Result, Rem);
309     // If the quotient is negative and the remainder is nonzero, round
310     // towards negative infinity by subtracting epsilon from the result.
311     if (ThisVal.isNegative() != OtherVal.isNegative() && !Rem.isZero())
312       Result = Result - 1;
313   } else
314     Result = ThisVal.udiv(OtherVal);
315   Result.setIsSigned(CommonFXSema.isSigned());
316 
317   // If our result lies outside of the representative range of the common
318   // semantic, we either have overflow or saturation.
319   APSInt Max = APFixedPoint::getMax(CommonFXSema).getValue()
320                                                  .extOrTrunc(Wide);
321   APSInt Min = APFixedPoint::getMin(CommonFXSema).getValue()
322                                                  .extOrTrunc(Wide);
323   if (CommonFXSema.isSaturated()) {
324     if (Result < Min)
325       Result = Min;
326     else if (Result > Max)
327       Result = Max;
328   } else
329     Overflowed = Result < Min || Result > Max;
330 
331   if (Overflow)
332     *Overflow = Overflowed;
333 
334   return APFixedPoint(Result.sextOrTrunc(CommonFXSema.getWidth()),
335                       CommonFXSema);
336 }
337 
338 APFixedPoint APFixedPoint::shl(unsigned Amt, bool *Overflow) const {
339   APSInt ThisVal = Val;
340   bool Overflowed = false;
341 
342   // Widen the LHS.
343   unsigned Wide = Sema.getWidth() * 2;
344   if (Sema.isSigned())
345     ThisVal = ThisVal.sext(Wide);
346   else
347     ThisVal = ThisVal.zext(Wide);
348 
349   // Clamp the shift amount at the original width, and perform the shift.
350   Amt = std::min(Amt, ThisVal.getBitWidth());
351   APSInt Result = ThisVal << Amt;
352   Result.setIsSigned(Sema.isSigned());
353 
354   // If our result lies outside of the representative range of the
355   // semantic, we either have overflow or saturation.
356   APSInt Max = APFixedPoint::getMax(Sema).getValue().extOrTrunc(Wide);
357   APSInt Min = APFixedPoint::getMin(Sema).getValue().extOrTrunc(Wide);
358   if (Sema.isSaturated()) {
359     if (Result < Min)
360       Result = Min;
361     else if (Result > Max)
362       Result = Max;
363   } else
364     Overflowed = Result < Min || Result > Max;
365 
366   if (Overflow)
367     *Overflow = Overflowed;
368 
369   return APFixedPoint(Result.sextOrTrunc(Sema.getWidth()), Sema);
370 }
371 
372 void APFixedPoint::toString(SmallVectorImpl<char> &Str) const {
373   APSInt Val = getValue();
374   unsigned Scale = getScale();
375 
376   if (Val.isSigned() && Val.isNegative() && Val != -Val) {
377     Val = -Val;
378     Str.push_back('-');
379   }
380 
381   APSInt IntPart = Val >> Scale;
382 
383   // Add 4 digits to hold the value after multiplying 10 (the radix)
384   unsigned Width = Val.getBitWidth() + 4;
385   APInt FractPart = Val.zextOrTrunc(Scale).zext(Width);
386   APInt FractPartMask = APInt::getAllOnes(Scale).zext(Width);
387   APInt RadixInt = APInt(Width, 10);
388 
389   IntPart.toString(Str, /*Radix=*/10);
390   Str.push_back('.');
391   do {
392     (FractPart * RadixInt)
393         .lshr(Scale)
394         .toString(Str, /*Radix=*/10, Val.isSigned());
395     FractPart = (FractPart * RadixInt) & FractPartMask;
396   } while (FractPart != 0);
397 }
398 
399 APFixedPoint APFixedPoint::negate(bool *Overflow) const {
400   if (!isSaturated()) {
401     if (Overflow)
402       *Overflow =
403           (!isSigned() && Val != 0) || (isSigned() && Val.isMinSignedValue());
404     return APFixedPoint(-Val, Sema);
405   }
406 
407   // We never overflow for saturation
408   if (Overflow)
409     *Overflow = false;
410 
411   if (isSigned())
412     return Val.isMinSignedValue() ? getMax(Sema) : APFixedPoint(-Val, Sema);
413   else
414     return APFixedPoint(Sema);
415 }
416 
417 APSInt APFixedPoint::convertToInt(unsigned DstWidth, bool DstSign,
418                                   bool *Overflow) const {
419   APSInt Result = getIntPart();
420   unsigned SrcWidth = getWidth();
421 
422   APSInt DstMin = APSInt::getMinValue(DstWidth, !DstSign);
423   APSInt DstMax = APSInt::getMaxValue(DstWidth, !DstSign);
424 
425   if (SrcWidth < DstWidth) {
426     Result = Result.extend(DstWidth);
427   } else if (SrcWidth > DstWidth) {
428     DstMin = DstMin.extend(SrcWidth);
429     DstMax = DstMax.extend(SrcWidth);
430   }
431 
432   if (Overflow) {
433     if (Result.isSigned() && !DstSign) {
434       *Overflow = Result.isNegative() || Result.ugt(DstMax);
435     } else if (Result.isUnsigned() && DstSign) {
436       *Overflow = Result.ugt(DstMax);
437     } else {
438       *Overflow = Result < DstMin || Result > DstMax;
439     }
440   }
441 
442   Result.setIsSigned(DstSign);
443   return Result.extOrTrunc(DstWidth);
444 }
445 
446 const fltSemantics *APFixedPoint::promoteFloatSemantics(const fltSemantics *S) {
447   if (S == &APFloat::BFloat())
448     return &APFloat::IEEEdouble();
449   else if (S == &APFloat::IEEEhalf())
450     return &APFloat::IEEEsingle();
451   else if (S == &APFloat::IEEEsingle())
452     return &APFloat::IEEEdouble();
453   else if (S == &APFloat::IEEEdouble())
454     return &APFloat::IEEEquad();
455   llvm_unreachable("Could not promote float type!");
456 }
457 
458 APFloat APFixedPoint::convertToFloat(const fltSemantics &FloatSema) const {
459   // For some operations, rounding mode has an effect on the result, while
460   // other operations are lossless and should never result in rounding.
461   // To signify which these operations are, we define two rounding modes here.
462   APFloat::roundingMode RM = APFloat::rmNearestTiesToEven;
463   APFloat::roundingMode LosslessRM = APFloat::rmTowardZero;
464 
465   // Make sure that we are operating in a type that works with this fixed-point
466   // semantic.
467   const fltSemantics *OpSema = &FloatSema;
468   while (!Sema.fitsInFloatSemantics(*OpSema))
469     OpSema = promoteFloatSemantics(OpSema);
470 
471   // Convert the fixed point value bits as an integer. If the floating point
472   // value does not have the required precision, we will round according to the
473   // given mode.
474   APFloat Flt(*OpSema);
475   APFloat::opStatus S = Flt.convertFromAPInt(Val, Sema.isSigned(), RM);
476 
477   // If we cared about checking for precision loss, we could look at this
478   // status.
479   (void)S;
480 
481   // Scale down the integer value in the float to match the correct scaling
482   // factor.
483   APFloat ScaleFactor(std::pow(2, -(int)Sema.getScale()));
484   bool Ignored;
485   ScaleFactor.convert(*OpSema, LosslessRM, &Ignored);
486   Flt.multiply(ScaleFactor, LosslessRM);
487 
488   if (OpSema != &FloatSema)
489     Flt.convert(FloatSema, RM, &Ignored);
490 
491   return Flt;
492 }
493 
494 APFixedPoint APFixedPoint::getFromIntValue(const APSInt &Value,
495                                            const FixedPointSemantics &DstFXSema,
496                                            bool *Overflow) {
497   FixedPointSemantics IntFXSema = FixedPointSemantics::GetIntegerSemantics(
498       Value.getBitWidth(), Value.isSigned());
499   return APFixedPoint(Value, IntFXSema).convert(DstFXSema, Overflow);
500 }
501 
502 APFixedPoint
503 APFixedPoint::getFromFloatValue(const APFloat &Value,
504                                 const FixedPointSemantics &DstFXSema,
505                                 bool *Overflow) {
506   // For some operations, rounding mode has an effect on the result, while
507   // other operations are lossless and should never result in rounding.
508   // To signify which these operations are, we define two rounding modes here,
509   // even though they are the same mode.
510   APFloat::roundingMode RM = APFloat::rmTowardZero;
511   APFloat::roundingMode LosslessRM = APFloat::rmTowardZero;
512 
513   const fltSemantics &FloatSema = Value.getSemantics();
514 
515   if (Value.isNaN()) {
516     // Handle NaN immediately.
517     if (Overflow)
518       *Overflow = true;
519     return APFixedPoint(DstFXSema);
520   }
521 
522   // Make sure that we are operating in a type that works with this fixed-point
523   // semantic.
524   const fltSemantics *OpSema = &FloatSema;
525   while (!DstFXSema.fitsInFloatSemantics(*OpSema))
526     OpSema = promoteFloatSemantics(OpSema);
527 
528   APFloat Val = Value;
529 
530   bool Ignored;
531   if (&FloatSema != OpSema)
532     Val.convert(*OpSema, LosslessRM, &Ignored);
533 
534   // Scale up the float so that the 'fractional' part of the mantissa ends up in
535   // the integer range instead. Rounding mode is irrelevant here.
536   // It is fine if this overflows to infinity even for saturating types,
537   // since we will use floating point comparisons to check for saturation.
538   APFloat ScaleFactor(std::pow(2, DstFXSema.getScale()));
539   ScaleFactor.convert(*OpSema, LosslessRM, &Ignored);
540   Val.multiply(ScaleFactor, LosslessRM);
541 
542   // Convert to the integral representation of the value. This rounding mode
543   // is significant.
544   APSInt Res(DstFXSema.getWidth(), !DstFXSema.isSigned());
545   Val.convertToInteger(Res, RM, &Ignored);
546 
547   // Round the integral value and scale back. This makes the
548   // overflow calculations below work properly. If we do not round here,
549   // we risk checking for overflow with a value that is outside the
550   // representable range of the fixed-point semantic even though no overflow
551   // would occur had we rounded first.
552   ScaleFactor = APFloat(std::pow(2, -(int)DstFXSema.getScale()));
553   ScaleFactor.convert(*OpSema, LosslessRM, &Ignored);
554   Val.roundToIntegral(RM);
555   Val.multiply(ScaleFactor, LosslessRM);
556 
557   // Check for overflow/saturation by checking if the floating point value
558   // is outside the range representable by the fixed-point value.
559   APFloat FloatMax = getMax(DstFXSema).convertToFloat(*OpSema);
560   APFloat FloatMin = getMin(DstFXSema).convertToFloat(*OpSema);
561   bool Overflowed = false;
562   if (DstFXSema.isSaturated()) {
563     if (Val > FloatMax)
564       Res = getMax(DstFXSema).getValue();
565     else if (Val < FloatMin)
566       Res = getMin(DstFXSema).getValue();
567   } else
568     Overflowed = Val > FloatMax || Val < FloatMin;
569 
570   if (Overflow)
571     *Overflow = Overflowed;
572 
573   return APFixedPoint(Res, DstFXSema);
574 }
575 
576 } // namespace llvm
577