xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
155b6f170SJeremy Kun //===- Polynomial.cpp - MLIR storage type for static Polynomial -*- C++ -*-===//
255b6f170SJeremy Kun //
355b6f170SJeremy Kun // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
455b6f170SJeremy Kun // See https://llvm.org/LICENSE.txt for license information.
555b6f170SJeremy Kun // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
655b6f170SJeremy Kun //
755b6f170SJeremy Kun //===----------------------------------------------------------------------===//
855b6f170SJeremy Kun 
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"

#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <cassert>
1255b6f170SJeremy Kun 
1355b6f170SJeremy Kun namespace mlir {
1455b6f170SJeremy Kun namespace polynomial {
1555b6f170SJeremy Kun 
16*2ff43ce8SJeremy Kun template <typename PolyT, typename MonomialT>
fromMonomialsImpl(ArrayRef<MonomialT> monomials)17*2ff43ce8SJeremy Kun FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
1855b6f170SJeremy Kun   // A polynomial's terms are canonically stored in order of increasing degree.
19*2ff43ce8SJeremy Kun   auto monomialsCopy = llvm::SmallVector<MonomialT>(monomials);
2055b6f170SJeremy Kun   std::sort(monomialsCopy.begin(), monomialsCopy.end());
2155b6f170SJeremy Kun 
2255b6f170SJeremy Kun   // Ensure non-unique exponents are not present. Since we sorted the list by
2355b6f170SJeremy Kun   // exponent, a linear scan of adjancent monomials suffices.
2455b6f170SJeremy Kun   if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
25*2ff43ce8SJeremy Kun                          [](const MonomialT &lhs, const MonomialT &rhs) {
26*2ff43ce8SJeremy Kun                            return lhs.getExponent() == rhs.getExponent();
2755b6f170SJeremy Kun                          }) != monomialsCopy.end()) {
2855b6f170SJeremy Kun     return failure();
2955b6f170SJeremy Kun   }
3055b6f170SJeremy Kun 
31*2ff43ce8SJeremy Kun   return PolyT(monomialsCopy);
3255b6f170SJeremy Kun }
3355b6f170SJeremy Kun 
34*2ff43ce8SJeremy Kun FailureOr<IntPolynomial>
fromMonomials(ArrayRef<IntMonomial> monomials)35*2ff43ce8SJeremy Kun IntPolynomial::fromMonomials(ArrayRef<IntMonomial> monomials) {
36*2ff43ce8SJeremy Kun   return fromMonomialsImpl<IntPolynomial, IntMonomial>(monomials);
37*2ff43ce8SJeremy Kun }
38*2ff43ce8SJeremy Kun 
39*2ff43ce8SJeremy Kun FailureOr<FloatPolynomial>
fromMonomials(ArrayRef<FloatMonomial> monomials)40*2ff43ce8SJeremy Kun FloatPolynomial::fromMonomials(ArrayRef<FloatMonomial> monomials) {
41*2ff43ce8SJeremy Kun   return fromMonomialsImpl<FloatPolynomial, FloatMonomial>(monomials);
42*2ff43ce8SJeremy Kun }
43*2ff43ce8SJeremy Kun 
44*2ff43ce8SJeremy Kun template <typename PolyT, typename MonomialT, typename CoeffT>
fromCoefficientsImpl(ArrayRef<CoeffT> coeffs)45*2ff43ce8SJeremy Kun PolyT fromCoefficientsImpl(ArrayRef<CoeffT> coeffs) {
46*2ff43ce8SJeremy Kun   llvm::SmallVector<MonomialT> monomials;
4755b6f170SJeremy Kun   auto size = coeffs.size();
4855b6f170SJeremy Kun   monomials.reserve(size);
4955b6f170SJeremy Kun   for (size_t i = 0; i < size; i++) {
5055b6f170SJeremy Kun     monomials.emplace_back(coeffs[i], i);
5155b6f170SJeremy Kun   }
52*2ff43ce8SJeremy Kun   auto result = PolyT::fromMonomials(monomials);
5355b6f170SJeremy Kun   // Construction guarantees unique exponents, so the failure mode of
5455b6f170SJeremy Kun   // fromMonomials can be bypassed.
5555b6f170SJeremy Kun   assert(succeeded(result));
5655b6f170SJeremy Kun   return result.value();
5755b6f170SJeremy Kun }
5855b6f170SJeremy Kun 
fromCoefficients(ArrayRef<int64_t> coeffs)59*2ff43ce8SJeremy Kun IntPolynomial IntPolynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
60*2ff43ce8SJeremy Kun   return fromCoefficientsImpl<IntPolynomial, IntMonomial, int64_t>(coeffs);
6155b6f170SJeremy Kun }
6255b6f170SJeremy Kun 
fromCoefficients(ArrayRef<double> coeffs)63*2ff43ce8SJeremy Kun FloatPolynomial FloatPolynomial::fromCoefficients(ArrayRef<double> coeffs) {
64*2ff43ce8SJeremy Kun   return fromCoefficientsImpl<FloatPolynomial, FloatMonomial, double>(coeffs);
6555b6f170SJeremy Kun }
6655b6f170SJeremy Kun 
6755b6f170SJeremy Kun } // namespace polynomial
6855b6f170SJeremy Kun } // namespace mlir
69