1 //===- Polynomial.h - A data class for polynomials --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
10 #define MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
11
12 #include "mlir/Support/LLVM.h"
13 #include "llvm/ADT/APFloat.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/ArrayRef.h"
16 #include "llvm/ADT/Hashing.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/raw_ostream.h"
20
21 namespace mlir {
22
23 class MLIRContext;
24
25 namespace polynomial {
26
/// Bit width of the `APInt`s used for statically defined polynomials: both the
/// integer coefficients and the exponents of monomials are built with this
/// width. This restricts statically defined polynomials to have at most 64-bit
/// coefficients. This may be relaxed in the future, but it seems unlikely one
/// would want to specify 128-bit polynomials statically in the source code.
///
/// Declared `inline` (C++17) so every translation unit including this header
/// shares a single definition instead of getting its own internal copy.
inline constexpr unsigned apintBitWidth = 64;
31
32 template <class Derived, typename CoefficientType>
33 class MonomialBase {
34 public:
MonomialBase(const CoefficientType & coeff,const APInt & expo)35 MonomialBase(const CoefficientType &coeff, const APInt &expo)
36 : coefficient(coeff), exponent(expo) {}
37 virtual ~MonomialBase() = default;
38
getCoefficient()39 const CoefficientType &getCoefficient() const { return coefficient; }
getMutableCoefficient()40 CoefficientType &getMutableCoefficient() { return coefficient; }
getExponent()41 const APInt &getExponent() const { return exponent; }
setCoefficient(const CoefficientType & coeff)42 void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
setExponent(const APInt & exp)43 void setExponent(const APInt &exp) { exponent = exp; }
44
45 bool operator==(const MonomialBase &other) const {
46 return other.coefficient == coefficient && other.exponent == exponent;
47 }
48 bool operator!=(const MonomialBase &other) const {
49 return other.coefficient != coefficient || other.exponent != exponent;
50 }
51
52 /// Monomials are ordered by exponent.
53 bool operator<(const MonomialBase &other) const {
54 return (exponent.ult(other.exponent));
55 }
56
add(const Derived & other)57 Derived add(const Derived &other) {
58 assert(exponent == other.exponent);
59 CoefficientType newCoeff = coefficient + other.coefficient;
60 Derived result;
61 result.setCoefficient(newCoeff);
62 result.setExponent(exponent);
63 return result;
64 }
65
66 virtual bool isMonic() const = 0;
67 virtual void
68 coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
69
70 template <class D, typename T>
71 friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
72
73 protected:
74 CoefficientType coefficient;
75 APInt exponent;
76 };
77
78 /// A class representing a monomial of a single-variable polynomial with integer
79 /// coefficients.
80 class IntMonomial : public MonomialBase<IntMonomial, APInt> {
81 public:
IntMonomial(int64_t coeff,uint64_t expo)82 IntMonomial(int64_t coeff, uint64_t expo)
83 : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
84
IntMonomial()85 IntMonomial()
86 : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
87
88 ~IntMonomial() override = default;
89
isMonic()90 bool isMonic() const override { return coefficient == 1; }
91
coefficientToString(llvm::SmallString<16> & coeffString)92 void coefficientToString(llvm::SmallString<16> &coeffString) const override {
93 coefficient.toStringSigned(coeffString);
94 }
95 };
96
97 /// A class representing a monomial of a single-variable polynomial with integer
98 /// coefficients.
99 class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
100 public:
FloatMonomial(double coeff,uint64_t expo)101 FloatMonomial(double coeff, uint64_t expo)
102 : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
103
FloatMonomial()104 FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
105
106 ~FloatMonomial() override = default;
107
isMonic()108 bool isMonic() const override { return coefficient == APFloat(1.0); }
109
coefficientToString(llvm::SmallString<16> & coeffString)110 void coefficientToString(llvm::SmallString<16> &coeffString) const override {
111 coefficient.toString(coeffString);
112 }
113 };
114
115 template <class Derived, typename Monomial>
116 class PolynomialBase {
117 public:
118 PolynomialBase() = delete;
119
PolynomialBase(ArrayRef<Monomial> terms)120 explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {};
121
122 explicit operator bool() const { return !terms.empty(); }
123 bool operator==(const PolynomialBase &other) const {
124 return other.terms == terms;
125 }
126 bool operator!=(const PolynomialBase &other) const {
127 return !(other.terms == terms);
128 }
129
print(raw_ostream & os,::llvm::StringRef separator,::llvm::StringRef exponentiation)130 void print(raw_ostream &os, ::llvm::StringRef separator,
131 ::llvm::StringRef exponentiation) const {
132 bool first = true;
133 for (const Monomial &term : getTerms()) {
134 if (first) {
135 first = false;
136 } else {
137 os << separator;
138 }
139 std::string coeffToPrint;
140 if (term.isMonic() && term.getExponent().uge(1)) {
141 coeffToPrint = "";
142 } else {
143 llvm::SmallString<16> coeffString;
144 term.coefficientToString(coeffString);
145 coeffToPrint = coeffString.str();
146 }
147
148 if (term.getExponent() == 0) {
149 os << coeffToPrint;
150 } else if (term.getExponent() == 1) {
151 os << coeffToPrint << "x";
152 } else {
153 llvm::SmallString<16> expString;
154 term.getExponent().toStringSigned(expString);
155 os << coeffToPrint << "x" << exponentiation << expString;
156 }
157 }
158 }
159
add(const Derived & other)160 Derived add(const Derived &other) {
161 SmallVector<Monomial> newTerms;
162 auto it1 = terms.begin();
163 auto it2 = other.terms.begin();
164 while (it1 != terms.end() || it2 != other.terms.end()) {
165 if (it1 == terms.end()) {
166 newTerms.emplace_back(*it2);
167 it2++;
168 continue;
169 }
170
171 if (it2 == other.terms.end()) {
172 newTerms.emplace_back(*it1);
173 it1++;
174 continue;
175 }
176
177 while (it1->getExponent().ult(it2->getExponent())) {
178 newTerms.emplace_back(*it1);
179 it1++;
180 if (it1 == terms.end())
181 break;
182 }
183
184 while (it2->getExponent().ult(it1->getExponent())) {
185 newTerms.emplace_back(*it2);
186 it2++;
187 if (it2 == terms.end())
188 break;
189 }
190
191 newTerms.emplace_back(it1->add(*it2));
192 it1++;
193 it2++;
194 }
195 return Derived(newTerms);
196 }
197
198 // Prints polynomial to 'os'.
print(raw_ostream & os)199 void print(raw_ostream &os) const { print(os, " + ", "**"); }
200
201 void dump() const;
202
203 // Prints polynomial so that it can be used as a valid identifier
toIdentifier()204 std::string toIdentifier() const {
205 std::string result;
206 llvm::raw_string_ostream os(result);
207 print(os, "_", "");
208 return os.str();
209 }
210
getDegree()211 unsigned getDegree() const {
212 return terms.back().getExponent().getZExtValue();
213 }
214
getTerms()215 ArrayRef<Monomial> getTerms() const { return terms; }
216
217 template <class D, typename T>
218 friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
219
220 private:
221 // The monomial terms for this polynomial.
222 SmallVector<Monomial> terms;
223 };
224
225 /// A single-variable polynomial with integer coefficients.
226 ///
227 /// Eg: x^1024 + x + 1
228 class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
229 public:
IntPolynomial(ArrayRef<IntMonomial> terms)230 explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
231
232 // Returns a Polynomial from a list of monomials.
233 // Fails if two monomials have the same exponent.
234 static FailureOr<IntPolynomial>
235 fromMonomials(ArrayRef<IntMonomial> monomials);
236
237 /// Returns a polynomial with coefficients given by `coeffs`. The value
238 /// coeffs[i] is converted to a monomial with exponent i.
239 static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
240 };
241
242 /// A single-variable polynomial with double coefficients.
243 ///
244 /// Eg: 1.0 x^1024 + 3.5 x + 1e-05
245 class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
246 public:
FloatPolynomial(ArrayRef<FloatMonomial> terms)247 explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
248 : PolynomialBase(terms) {}
249
250 // Returns a Polynomial from a list of monomials.
251 // Fails if two monomials have the same exponent.
252 static FailureOr<FloatPolynomial>
253 fromMonomials(ArrayRef<FloatMonomial> monomials);
254
255 /// Returns a polynomial with coefficients given by `coeffs`. The value
256 /// coeffs[i] is converted to a monomial with exponent i.
257 static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
258 };
259
260 // Make Polynomials hashable.
261 template <class D, typename T>
hash_value(const PolynomialBase<D,T> & arg)262 inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
263 return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
264 }
265
266 template <class D, typename T>
hash_value(const MonomialBase<D,T> & arg)267 inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
268 return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
269 ::llvm::hash_value(arg.exponent));
270 }
271
272 template <class D, typename T>
273 inline raw_ostream &operator<<(raw_ostream &os,
274 const PolynomialBase<D, T> &polynomial) {
275 polynomial.print(os);
276 return os;
277 }
278
279 } // namespace polynomial
280 } // namespace mlir
281
282 #endif // MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
283