xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp (revision 4425dfba6a1f394e958e94aa471a07bcf707136a)
155b6f170SJeremy Kun //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
255b6f170SJeremy Kun //
355b6f170SJeremy Kun // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
455b6f170SJeremy Kun // See https://llvm.org/LICENSE.txt for license information.
555b6f170SJeremy Kun // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
655b6f170SJeremy Kun //
755b6f170SJeremy Kun //===----------------------------------------------------------------------===//
855b6f170SJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
955b6f170SJeremy Kun 
1055b6f170SJeremy Kun #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
1155b6f170SJeremy Kun #include "mlir/Support/LLVM.h"
122ff43ce8SJeremy Kun #include "llvm/ADT/SmallVector.h"
1355b6f170SJeremy Kun #include "llvm/ADT/StringExtras.h"
1455b6f170SJeremy Kun #include "llvm/ADT/StringRef.h"
1555b6f170SJeremy Kun #include "llvm/ADT/StringSet.h"
1655b6f170SJeremy Kun 
1755b6f170SJeremy Kun namespace mlir {
1855b6f170SJeremy Kun namespace polynomial {
1955b6f170SJeremy Kun 
202ff43ce8SJeremy Kun void IntPolynomialAttr::print(AsmPrinter &p) const {
212ff43ce8SJeremy Kun   p << '<' << getPolynomial() << '>';
2255b6f170SJeremy Kun }
2355b6f170SJeremy Kun 
242ff43ce8SJeremy Kun void FloatPolynomialAttr::print(AsmPrinter &p) const {
252ff43ce8SJeremy Kun   p << '<' << getPolynomial() << '>';
262ff43ce8SJeremy Kun }
272ff43ce8SJeremy Kun 
282ff43ce8SJeremy Kun /// A callable that parses the coefficient using the appropriate method for the
292ff43ce8SJeremy Kun /// given monomial type, and stores the parsed coefficient value on the
302ff43ce8SJeremy Kun /// monomial.
312ff43ce8SJeremy Kun template <typename MonomialType>
322ff43ce8SJeremy Kun using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
332ff43ce8SJeremy Kun 
3455b6f170SJeremy Kun /// Try to parse a monomial. If successful, populate the fields of the outparam
3555b6f170SJeremy Kun /// `monomial` with the results, and the `variable` outparam with the parsed
3655b6f170SJeremy Kun /// variable name. Sets shouldParseMore to true if the monomial is followed by
3755b6f170SJeremy Kun /// a '+'.
382ff43ce8SJeremy Kun ///
392ff43ce8SJeremy Kun template <typename Monomial>
402ff43ce8SJeremy Kun ParseResult
412ff43ce8SJeremy Kun parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
422ff43ce8SJeremy Kun               bool &isConstantTerm, bool &shouldParseMore,
432ff43ce8SJeremy Kun               ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
442ff43ce8SJeremy Kun   OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
4555b6f170SJeremy Kun 
4655b6f170SJeremy Kun   isConstantTerm = false;
4755b6f170SJeremy Kun   shouldParseMore = false;
4855b6f170SJeremy Kun 
4955b6f170SJeremy Kun   // A + indicates it's a constant term with more to go, as in `1 + x`.
5055b6f170SJeremy Kun   if (succeeded(parser.parseOptionalPlus())) {
5155b6f170SJeremy Kun     // If no coefficient was parsed, and there's a +, then it's effectively
5255b6f170SJeremy Kun     // parsing an empty string.
5355b6f170SJeremy Kun     if (!parsedCoeffResult.has_value()) {
5455b6f170SJeremy Kun       return failure();
5555b6f170SJeremy Kun     }
562ff43ce8SJeremy Kun     monomial.setExponent(APInt(apintBitWidth, 0));
5755b6f170SJeremy Kun     isConstantTerm = true;
5855b6f170SJeremy Kun     shouldParseMore = true;
5955b6f170SJeremy Kun     return success();
6055b6f170SJeremy Kun   }
6155b6f170SJeremy Kun 
6255b6f170SJeremy Kun   // A monomial can be a trailing constant term, as in `x + 1`.
6355b6f170SJeremy Kun   if (failed(parser.parseOptionalKeyword(&variable))) {
6455b6f170SJeremy Kun     // If neither a coefficient nor a variable was found, then it's effectively
6555b6f170SJeremy Kun     // parsing an empty string.
6655b6f170SJeremy Kun     if (!parsedCoeffResult.has_value()) {
6755b6f170SJeremy Kun       return failure();
6855b6f170SJeremy Kun     }
6955b6f170SJeremy Kun 
702ff43ce8SJeremy Kun     monomial.setExponent(APInt(apintBitWidth, 0));
7155b6f170SJeremy Kun     isConstantTerm = true;
7255b6f170SJeremy Kun     return success();
7355b6f170SJeremy Kun   }
7455b6f170SJeremy Kun 
7555b6f170SJeremy Kun   // Parse exponentiation symbol as `**`. We can't use caret because it's
7655b6f170SJeremy Kun   // reserved for basic block identifiers If no star is present, it's treated
7755b6f170SJeremy Kun   // as a polynomial with exponent 1.
7855b6f170SJeremy Kun   if (succeeded(parser.parseOptionalStar())) {
7955b6f170SJeremy Kun     // If there's one * there must be two.
8055b6f170SJeremy Kun     if (failed(parser.parseStar())) {
8155b6f170SJeremy Kun       return failure();
8255b6f170SJeremy Kun     }
8355b6f170SJeremy Kun 
8455b6f170SJeremy Kun     // If there's a **, then the integer exponent is required.
8555b6f170SJeremy Kun     APInt parsedExponent(apintBitWidth, 0);
8655b6f170SJeremy Kun     if (failed(parser.parseInteger(parsedExponent))) {
8755b6f170SJeremy Kun       parser.emitError(parser.getCurrentLocation(),
8855b6f170SJeremy Kun                        "found invalid integer exponent");
8955b6f170SJeremy Kun       return failure();
9055b6f170SJeremy Kun     }
9155b6f170SJeremy Kun 
922ff43ce8SJeremy Kun     monomial.setExponent(parsedExponent);
9355b6f170SJeremy Kun   } else {
942ff43ce8SJeremy Kun     monomial.setExponent(APInt(apintBitWidth, 1));
9555b6f170SJeremy Kun   }
9655b6f170SJeremy Kun 
9755b6f170SJeremy Kun   if (succeeded(parser.parseOptionalPlus())) {
9855b6f170SJeremy Kun     shouldParseMore = true;
9955b6f170SJeremy Kun   }
10055b6f170SJeremy Kun   return success();
10155b6f170SJeremy Kun }
10255b6f170SJeremy Kun 
103ab29203eSJeremy Kun template <typename Monomial>
1042ff43ce8SJeremy Kun LogicalResult
1052ff43ce8SJeremy Kun parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
1062ff43ce8SJeremy Kun                     llvm::StringSet<> &variables,
1072ff43ce8SJeremy Kun                     ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
10855b6f170SJeremy Kun   while (true) {
10955b6f170SJeremy Kun     Monomial parsedMonomial;
11055b6f170SJeremy Kun     llvm::StringRef parsedVariableRef;
11155b6f170SJeremy Kun     bool isConstantTerm;
11255b6f170SJeremy Kun     bool shouldParseMore;
1132ff43ce8SJeremy Kun     if (failed(parseMonomial<Monomial>(
1142ff43ce8SJeremy Kun             parser, parsedMonomial, parsedVariableRef, isConstantTerm,
1152ff43ce8SJeremy Kun             shouldParseMore, parseAndStoreCoefficient))) {
11655b6f170SJeremy Kun       parser.emitError(parser.getCurrentLocation(), "expected a monomial");
1172ff43ce8SJeremy Kun       return failure();
11855b6f170SJeremy Kun     }
11955b6f170SJeremy Kun 
12055b6f170SJeremy Kun     if (!isConstantTerm) {
12155b6f170SJeremy Kun       std::string parsedVariable = parsedVariableRef.str();
12255b6f170SJeremy Kun       variables.insert(parsedVariable);
12355b6f170SJeremy Kun     }
12455b6f170SJeremy Kun     monomials.push_back(parsedMonomial);
12555b6f170SJeremy Kun 
12655b6f170SJeremy Kun     if (shouldParseMore)
12755b6f170SJeremy Kun       continue;
12855b6f170SJeremy Kun 
12955b6f170SJeremy Kun     if (succeeded(parser.parseOptionalGreater())) {
13055b6f170SJeremy Kun       break;
13155b6f170SJeremy Kun     }
13255b6f170SJeremy Kun     parser.emitError(
13355b6f170SJeremy Kun         parser.getCurrentLocation(),
13455b6f170SJeremy Kun         "expected + and more monomials, or > to end polynomial attribute");
1352ff43ce8SJeremy Kun     return failure();
13655b6f170SJeremy Kun   }
13755b6f170SJeremy Kun 
13855b6f170SJeremy Kun   if (variables.size() > 1) {
13955b6f170SJeremy Kun     std::string vars = llvm::join(variables.keys(), ", ");
14055b6f170SJeremy Kun     parser.emitError(
14155b6f170SJeremy Kun         parser.getCurrentLocation(),
14255b6f170SJeremy Kun         "polynomials must have one indeterminate, but there were multiple: " +
14355b6f170SJeremy Kun             vars);
1442ff43ce8SJeremy Kun     return failure();
14555b6f170SJeremy Kun   }
14655b6f170SJeremy Kun 
1472ff43ce8SJeremy Kun   return success();
1482ff43ce8SJeremy Kun }
1492ff43ce8SJeremy Kun 
1502ff43ce8SJeremy Kun Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
1512ff43ce8SJeremy Kun   if (failed(parser.parseLess()))
1522ff43ce8SJeremy Kun     return {};
1532ff43ce8SJeremy Kun 
1542ff43ce8SJeremy Kun   llvm::SmallVector<IntMonomial> monomials;
1552ff43ce8SJeremy Kun   llvm::StringSet<> variables;
1562ff43ce8SJeremy Kun 
157ab29203eSJeremy Kun   if (failed(parsePolynomialAttr<IntMonomial>(
1582ff43ce8SJeremy Kun           parser, monomials, variables,
1592ff43ce8SJeremy Kun           [&](IntMonomial &monomial) -> OptionalParseResult {
1602ff43ce8SJeremy Kun             APInt parsedCoeff(apintBitWidth, 1);
1612ff43ce8SJeremy Kun             OptionalParseResult result =
1622ff43ce8SJeremy Kun                 parser.parseOptionalInteger(parsedCoeff);
1632ff43ce8SJeremy Kun             monomial.setCoefficient(parsedCoeff);
1642ff43ce8SJeremy Kun             return result;
1652ff43ce8SJeremy Kun           }))) {
1662ff43ce8SJeremy Kun     return {};
1672ff43ce8SJeremy Kun   }
1682ff43ce8SJeremy Kun 
1692ff43ce8SJeremy Kun   auto result = IntPolynomial::fromMonomials(monomials);
17055b6f170SJeremy Kun   if (failed(result)) {
17155b6f170SJeremy Kun     parser.emitError(parser.getCurrentLocation())
17255b6f170SJeremy Kun         << "parsed polynomial must have unique exponents among monomials";
17355b6f170SJeremy Kun     return {};
17455b6f170SJeremy Kun   }
1752ff43ce8SJeremy Kun   return IntPolynomialAttr::get(parser.getContext(), result.value());
17655b6f170SJeremy Kun }
1772ff43ce8SJeremy Kun Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
17855b6f170SJeremy Kun   if (failed(parser.parseLess()))
17955b6f170SJeremy Kun     return {};
18055b6f170SJeremy Kun 
1812ff43ce8SJeremy Kun   llvm::SmallVector<FloatMonomial> monomials;
1822ff43ce8SJeremy Kun   llvm::StringSet<> variables;
18355b6f170SJeremy Kun 
1842ff43ce8SJeremy Kun   ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient =
1852ff43ce8SJeremy Kun       [&](FloatMonomial &monomial) -> OptionalParseResult {
1862ff43ce8SJeremy Kun     double coeffValue = 1.0;
1872ff43ce8SJeremy Kun     ParseResult result = parser.parseFloat(coeffValue);
1882ff43ce8SJeremy Kun     monomial.setCoefficient(APFloat(coeffValue));
1892ff43ce8SJeremy Kun     return OptionalParseResult(result);
1902ff43ce8SJeremy Kun   };
19155b6f170SJeremy Kun 
192ab29203eSJeremy Kun   if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables,
193ab29203eSJeremy Kun                                                 parseAndStoreCoefficient))) {
19455b6f170SJeremy Kun     return {};
19555b6f170SJeremy Kun   }
1962ff43ce8SJeremy Kun 
1972ff43ce8SJeremy Kun   auto result = FloatPolynomial::fromMonomials(monomials);
19855b6f170SJeremy Kun   if (failed(result)) {
1992ff43ce8SJeremy Kun     parser.emitError(parser.getCurrentLocation())
2002ff43ce8SJeremy Kun         << "parsed polynomial must have unique exponents among monomials";
20155b6f170SJeremy Kun     return {};
20255b6f170SJeremy Kun   }
2032ff43ce8SJeremy Kun   return FloatPolynomialAttr::get(parser.getContext(), result.value());
20455b6f170SJeremy Kun }
20555b6f170SJeremy Kun 
206*4425dfbaSHongren Zheng LogicalResult
207*4425dfbaSHongren Zheng RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
208*4425dfbaSHongren Zheng                  Type coefficientType, IntegerAttr coefficientModulus,
209*4425dfbaSHongren Zheng                  IntPolynomialAttr polynomialModulus) {
210*4425dfbaSHongren Zheng   if (coefficientModulus) {
211*4425dfbaSHongren Zheng     auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType);
212*4425dfbaSHongren Zheng     if (!coeffIntType) {
213*4425dfbaSHongren Zheng       return emitError() << "coefficientModulus specified but coefficientType "
214*4425dfbaSHongren Zheng                             "is not integral";
215*4425dfbaSHongren Zheng     }
216*4425dfbaSHongren Zheng     APInt coeffModValue = coefficientModulus.getValue();
217*4425dfbaSHongren Zheng     if (coeffModValue == 0) {
218*4425dfbaSHongren Zheng       return emitError() << "coefficientModulus should not be 0";
219*4425dfbaSHongren Zheng     }
220*4425dfbaSHongren Zheng     if (coeffModValue.slt(0)) {
221*4425dfbaSHongren Zheng       return emitError() << "coefficientModulus should be positive";
222*4425dfbaSHongren Zheng     }
223*4425dfbaSHongren Zheng     auto coeffModWidth = (coeffModValue - 1).getActiveBits();
224*4425dfbaSHongren Zheng     auto coeffWidth = coeffIntType.getWidth();
225*4425dfbaSHongren Zheng     if (coeffModWidth > coeffWidth) {
226*4425dfbaSHongren Zheng       return emitError() << "coefficientModulus needs bit width of "
227*4425dfbaSHongren Zheng                          << coeffModWidth
228*4425dfbaSHongren Zheng                          << " but coefficientType can only contain "
229*4425dfbaSHongren Zheng                          << coeffWidth << " bits";
230*4425dfbaSHongren Zheng     }
231*4425dfbaSHongren Zheng   }
232*4425dfbaSHongren Zheng   return success();
233*4425dfbaSHongren Zheng }
234*4425dfbaSHongren Zheng 
23555b6f170SJeremy Kun } // namespace polynomial
23655b6f170SJeremy Kun } // namespace mlir
237