// xref: /llvm-project/flang/lib/Evaluate/complex.cpp (revision 1444e5acfb75630c23b118c39454a05cf3792d35)
//===-- lib/Evaluate/complex.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/complex.h"
10 #include "llvm/Support/raw_ostream.h"
11 
12 namespace Fortran::evaluate::value {
13 
14 template <typename R>
Add(const Complex & that,Rounding rounding) const15 ValueWithRealFlags<Complex<R>> Complex<R>::Add(
16     const Complex &that, Rounding rounding) const {
17   RealFlags flags;
18   Part reSum{re_.Add(that.re_, rounding).AccumulateFlags(flags)};
19   Part imSum{im_.Add(that.im_, rounding).AccumulateFlags(flags)};
20   return {Complex{reSum, imSum}, flags};
21 }
22 
23 template <typename R>
Subtract(const Complex & that,Rounding rounding) const24 ValueWithRealFlags<Complex<R>> Complex<R>::Subtract(
25     const Complex &that, Rounding rounding) const {
26   RealFlags flags;
27   Part reDiff{re_.Subtract(that.re_, rounding).AccumulateFlags(flags)};
28   Part imDiff{im_.Subtract(that.im_, rounding).AccumulateFlags(flags)};
29   return {Complex{reDiff, imDiff}, flags};
30 }
31 
32 template <typename R>
Multiply(const Complex & that,Rounding rounding) const33 ValueWithRealFlags<Complex<R>> Complex<R>::Multiply(
34     const Complex &that, Rounding rounding) const {
35   // (a + ib)*(c + id) -> ac - bd + i(ad + bc)
36   RealFlags flags;
37   Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
38   Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
39   Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
40   Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
41   Part acbd{ac.Subtract(bd, rounding).AccumulateFlags(flags)};
42   Part adbc{ad.Add(bc, rounding).AccumulateFlags(flags)};
43   return {Complex{acbd, adbc}, flags};
44 }
45 
46 template <typename R>
Divide(const Complex & that,Rounding rounding) const47 ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
48     const Complex &that, Rounding rounding) const {
49   // (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)]
50   //   -> [ac+bd+i(bc-ad)] / (cc+dd)  -- note (cc+dd) is real
51   //   -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd))
52   RealFlags flags;
53   Part cc{that.re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
54   Part dd{that.im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
55   Part ccPdd{cc.Add(dd, rounding).AccumulateFlags(flags)};
56   if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) {
57     // den = (cc+dd) did not overflow or underflow; try the naive
58     // sequence without scaling to avoid extra roundings.
59     Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
60     Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
61     Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
62     Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
63     Part acPbd{ac.Add(bd, rounding).AccumulateFlags(flags)};
64     Part bcSad{bc.Subtract(ad, rounding).AccumulateFlags(flags)};
65     Part re{acPbd.Divide(ccPdd, rounding).AccumulateFlags(flags)};
66     Part im{bcSad.Divide(ccPdd, rounding).AccumulateFlags(flags)};
67     if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) {
68       return {Complex{re, im}, flags};
69     }
70   }
71   // Scale numerator and denominator by d/c (if c>=d) or c/d (if c<d)
72   flags.clear();
73   Part scale; // will be <= 1.0 in magnitude
74   bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less};
75   if (cGEd) {
76     scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags);
77   } else {
78     scale = that.re_.Divide(that.im_, rounding).AccumulateFlags(flags);
79   }
80   Part den;
81   if (cGEd) {
82     Part dS{scale.Multiply(that.im_, rounding).AccumulateFlags(flags)};
83     den = dS.Add(that.re_, rounding).AccumulateFlags(flags);
84   } else {
85     Part cS{scale.Multiply(that.re_, rounding).AccumulateFlags(flags)};
86     den = cS.Add(that.im_, rounding).AccumulateFlags(flags);
87   }
88   Part aS{scale.Multiply(re_, rounding).AccumulateFlags(flags)};
89   Part bS{scale.Multiply(im_, rounding).AccumulateFlags(flags)};
90   Part re1, im1;
91   if (cGEd) {
92     re1 = re_.Add(bS, rounding).AccumulateFlags(flags);
93     im1 = im_.Subtract(aS, rounding).AccumulateFlags(flags);
94   } else {
95     re1 = aS.Add(im_, rounding).AccumulateFlags(flags);
96     im1 = bS.Subtract(re_, rounding).AccumulateFlags(flags);
97   }
98   Part re{re1.Divide(den, rounding).AccumulateFlags(flags)};
99   Part im{im1.Divide(den, rounding).AccumulateFlags(flags)};
100   return {Complex{re, im}, flags};
101 }
102 
DumpHexadecimal() const103 template <typename R> std::string Complex<R>::DumpHexadecimal() const {
104   std::string result{'('};
105   result += re_.DumpHexadecimal();
106   result += ',';
107   result += im_.DumpHexadecimal();
108   result += ')';
109   return result;
110 }
111 
112 template <typename R>
AsFortran(llvm::raw_ostream & o,int kind) const113 llvm::raw_ostream &Complex<R>::AsFortran(llvm::raw_ostream &o, int kind) const {
114   re_.AsFortran(o << '(', kind);
115   im_.AsFortran(o << ',', kind);
116   return o << ')';
117 }
118 
119 template class Complex<Real<Integer<16>, 11>>;
120 template class Complex<Real<Integer<16>, 8>>;
121 template class Complex<Real<Integer<32>, 24>>;
122 template class Complex<Real<Integer<64>, 53>>;
123 template class Complex<Real<X87IntegerContainer, 64>>;
124 template class Complex<Real<Integer<128>, 113>>;
125 } // namespace Fortran::evaluate::value
126