xref: /llvm-project/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- Polynomial.h - A data class for polynomials --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
10 #define MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
11 
12 #include "mlir/Support/LLVM.h"
13 #include "llvm/ADT/APFloat.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/ArrayRef.h"
16 #include "llvm/ADT/Hashing.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 namespace mlir {
22 
23 class MLIRContext;
24 
25 namespace polynomial {
26 
/// Bit width of the APInts that hold statically specified coefficients and
/// exponents. Limiting statically defined polynomials to 64-bit coefficients
/// may be relaxed in the future, but spelling out 128-bit polynomials directly
/// in source code seems unlikely to be needed.
constexpr unsigned apintBitWidth{64};
31 
32 template <class Derived, typename CoefficientType>
33 class MonomialBase {
34 public:
MonomialBase(const CoefficientType & coeff,const APInt & expo)35   MonomialBase(const CoefficientType &coeff, const APInt &expo)
36       : coefficient(coeff), exponent(expo) {}
37   virtual ~MonomialBase() = default;
38 
getCoefficient()39   const CoefficientType &getCoefficient() const { return coefficient; }
getMutableCoefficient()40   CoefficientType &getMutableCoefficient() { return coefficient; }
getExponent()41   const APInt &getExponent() const { return exponent; }
setCoefficient(const CoefficientType & coeff)42   void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
setExponent(const APInt & exp)43   void setExponent(const APInt &exp) { exponent = exp; }
44 
45   bool operator==(const MonomialBase &other) const {
46     return other.coefficient == coefficient && other.exponent == exponent;
47   }
48   bool operator!=(const MonomialBase &other) const {
49     return other.coefficient != coefficient || other.exponent != exponent;
50   }
51 
52   /// Monomials are ordered by exponent.
53   bool operator<(const MonomialBase &other) const {
54     return (exponent.ult(other.exponent));
55   }
56 
add(const Derived & other)57   Derived add(const Derived &other) {
58     assert(exponent == other.exponent);
59     CoefficientType newCoeff = coefficient + other.coefficient;
60     Derived result;
61     result.setCoefficient(newCoeff);
62     result.setExponent(exponent);
63     return result;
64   }
65 
66   virtual bool isMonic() const = 0;
67   virtual void
68   coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
69 
70   template <class D, typename T>
71   friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
72 
73 protected:
74   CoefficientType coefficient;
75   APInt exponent;
76 };
77 
78 /// A class representing a monomial of a single-variable polynomial with integer
79 /// coefficients.
80 class IntMonomial : public MonomialBase<IntMonomial, APInt> {
81 public:
IntMonomial(int64_t coeff,uint64_t expo)82   IntMonomial(int64_t coeff, uint64_t expo)
83       : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
84 
IntMonomial()85   IntMonomial()
86       : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
87 
88   ~IntMonomial() override = default;
89 
isMonic()90   bool isMonic() const override { return coefficient == 1; }
91 
coefficientToString(llvm::SmallString<16> & coeffString)92   void coefficientToString(llvm::SmallString<16> &coeffString) const override {
93     coefficient.toStringSigned(coeffString);
94   }
95 };
96 
97 /// A class representing a monomial of a single-variable polynomial with integer
98 /// coefficients.
99 class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
100 public:
FloatMonomial(double coeff,uint64_t expo)101   FloatMonomial(double coeff, uint64_t expo)
102       : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
103 
FloatMonomial()104   FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
105 
106   ~FloatMonomial() override = default;
107 
isMonic()108   bool isMonic() const override { return coefficient == APFloat(1.0); }
109 
coefficientToString(llvm::SmallString<16> & coeffString)110   void coefficientToString(llvm::SmallString<16> &coeffString) const override {
111     coefficient.toString(coeffString);
112   }
113 };
114 
115 template <class Derived, typename Monomial>
116 class PolynomialBase {
117 public:
118   PolynomialBase() = delete;
119 
PolynomialBase(ArrayRef<Monomial> terms)120   explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {};
121 
122   explicit operator bool() const { return !terms.empty(); }
123   bool operator==(const PolynomialBase &other) const {
124     return other.terms == terms;
125   }
126   bool operator!=(const PolynomialBase &other) const {
127     return !(other.terms == terms);
128   }
129 
print(raw_ostream & os,::llvm::StringRef separator,::llvm::StringRef exponentiation)130   void print(raw_ostream &os, ::llvm::StringRef separator,
131              ::llvm::StringRef exponentiation) const {
132     bool first = true;
133     for (const Monomial &term : getTerms()) {
134       if (first) {
135         first = false;
136       } else {
137         os << separator;
138       }
139       std::string coeffToPrint;
140       if (term.isMonic() && term.getExponent().uge(1)) {
141         coeffToPrint = "";
142       } else {
143         llvm::SmallString<16> coeffString;
144         term.coefficientToString(coeffString);
145         coeffToPrint = coeffString.str();
146       }
147 
148       if (term.getExponent() == 0) {
149         os << coeffToPrint;
150       } else if (term.getExponent() == 1) {
151         os << coeffToPrint << "x";
152       } else {
153         llvm::SmallString<16> expString;
154         term.getExponent().toStringSigned(expString);
155         os << coeffToPrint << "x" << exponentiation << expString;
156       }
157     }
158   }
159 
add(const Derived & other)160   Derived add(const Derived &other) {
161     SmallVector<Monomial> newTerms;
162     auto it1 = terms.begin();
163     auto it2 = other.terms.begin();
164     while (it1 != terms.end() || it2 != other.terms.end()) {
165       if (it1 == terms.end()) {
166         newTerms.emplace_back(*it2);
167         it2++;
168         continue;
169       }
170 
171       if (it2 == other.terms.end()) {
172         newTerms.emplace_back(*it1);
173         it1++;
174         continue;
175       }
176 
177       while (it1->getExponent().ult(it2->getExponent())) {
178         newTerms.emplace_back(*it1);
179         it1++;
180         if (it1 == terms.end())
181           break;
182       }
183 
184       while (it2->getExponent().ult(it1->getExponent())) {
185         newTerms.emplace_back(*it2);
186         it2++;
187         if (it2 == terms.end())
188           break;
189       }
190 
191       newTerms.emplace_back(it1->add(*it2));
192       it1++;
193       it2++;
194     }
195     return Derived(newTerms);
196   }
197 
198   // Prints polynomial to 'os'.
print(raw_ostream & os)199   void print(raw_ostream &os) const { print(os, " + ", "**"); }
200 
201   void dump() const;
202 
203   // Prints polynomial so that it can be used as a valid identifier
toIdentifier()204   std::string toIdentifier() const {
205     std::string result;
206     llvm::raw_string_ostream os(result);
207     print(os, "_", "");
208     return os.str();
209   }
210 
getDegree()211   unsigned getDegree() const {
212     return terms.back().getExponent().getZExtValue();
213   }
214 
getTerms()215   ArrayRef<Monomial> getTerms() const { return terms; }
216 
217   template <class D, typename T>
218   friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
219 
220 private:
221   // The monomial terms for this polynomial.
222   SmallVector<Monomial> terms;
223 };
224 
225 /// A single-variable polynomial with integer coefficients.
226 ///
227 /// Eg: x^1024 + x + 1
228 class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
229 public:
IntPolynomial(ArrayRef<IntMonomial> terms)230   explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
231 
232   // Returns a Polynomial from a list of monomials.
233   // Fails if two monomials have the same exponent.
234   static FailureOr<IntPolynomial>
235   fromMonomials(ArrayRef<IntMonomial> monomials);
236 
237   /// Returns a polynomial with coefficients given by `coeffs`. The value
238   /// coeffs[i] is converted to a monomial with exponent i.
239   static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
240 };
241 
242 /// A single-variable polynomial with double coefficients.
243 ///
244 /// Eg: 1.0 x^1024 + 3.5 x + 1e-05
245 class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
246 public:
FloatPolynomial(ArrayRef<FloatMonomial> terms)247   explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
248       : PolynomialBase(terms) {}
249 
250   // Returns a Polynomial from a list of monomials.
251   // Fails if two monomials have the same exponent.
252   static FailureOr<FloatPolynomial>
253   fromMonomials(ArrayRef<FloatMonomial> monomials);
254 
255   /// Returns a polynomial with coefficients given by `coeffs`. The value
256   /// coeffs[i] is converted to a monomial with exponent i.
257   static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
258 };
259 
260 // Make Polynomials hashable.
261 template <class D, typename T>
hash_value(const PolynomialBase<D,T> & arg)262 inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
263   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
264 }
265 
266 template <class D, typename T>
hash_value(const MonomialBase<D,T> & arg)267 inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
268   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
269                             ::llvm::hash_value(arg.exponent));
270 }
271 
272 template <class D, typename T>
273 inline raw_ostream &operator<<(raw_ostream &os,
274                                const PolynomialBase<D, T> &polynomial) {
275   polynomial.print(os);
276   return os;
277 }
278 
279 } // namespace polynomial
280 } // namespace mlir
281 
282 #endif // MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
283