155b6f170SJeremy Kun //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// 255b6f170SJeremy Kun // 355b6f170SJeremy Kun // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 455b6f170SJeremy Kun // See https://llvm.org/LICENSE.txt for license information. 555b6f170SJeremy Kun // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 655b6f170SJeremy Kun // 755b6f170SJeremy Kun //===----------------------------------------------------------------------===// 855b6f170SJeremy Kun #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" 955b6f170SJeremy Kun 1055b6f170SJeremy Kun #include "mlir/Dialect/Polynomial/IR/Polynomial.h" 1155b6f170SJeremy Kun #include "mlir/Support/LLVM.h" 122ff43ce8SJeremy Kun #include "llvm/ADT/SmallVector.h" 1355b6f170SJeremy Kun #include "llvm/ADT/StringExtras.h" 1455b6f170SJeremy Kun #include "llvm/ADT/StringRef.h" 1555b6f170SJeremy Kun #include "llvm/ADT/StringSet.h" 1655b6f170SJeremy Kun 1755b6f170SJeremy Kun namespace mlir { 1855b6f170SJeremy Kun namespace polynomial { 1955b6f170SJeremy Kun 202ff43ce8SJeremy Kun void IntPolynomialAttr::print(AsmPrinter &p) const { 212ff43ce8SJeremy Kun p << '<' << getPolynomial() << '>'; 2255b6f170SJeremy Kun } 2355b6f170SJeremy Kun 242ff43ce8SJeremy Kun void FloatPolynomialAttr::print(AsmPrinter &p) const { 252ff43ce8SJeremy Kun p << '<' << getPolynomial() << '>'; 262ff43ce8SJeremy Kun } 272ff43ce8SJeremy Kun 282ff43ce8SJeremy Kun /// A callable that parses the coefficient using the appropriate method for the 292ff43ce8SJeremy Kun /// given monomial type, and stores the parsed coefficient value on the 302ff43ce8SJeremy Kun /// monomial. 312ff43ce8SJeremy Kun template <typename MonomialType> 322ff43ce8SJeremy Kun using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>; 332ff43ce8SJeremy Kun 3455b6f170SJeremy Kun /// Try to parse a monomial. If successful, populate the fields of the outparam 3555b6f170SJeremy Kun /// `monomial` with the results, and the `variable` outparam with the parsed 3655b6f170SJeremy Kun /// variable name. Sets shouldParseMore to true if the monomial is followed by 3755b6f170SJeremy Kun /// a '+'. 382ff43ce8SJeremy Kun /// 392ff43ce8SJeremy Kun template <typename Monomial> 402ff43ce8SJeremy Kun ParseResult 412ff43ce8SJeremy Kun parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, 422ff43ce8SJeremy Kun bool &isConstantTerm, bool &shouldParseMore, 432ff43ce8SJeremy Kun ParseCoefficientFn<Monomial> parseAndStoreCoefficient) { 442ff43ce8SJeremy Kun OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial); 4555b6f170SJeremy Kun 4655b6f170SJeremy Kun isConstantTerm = false; 4755b6f170SJeremy Kun shouldParseMore = false; 4855b6f170SJeremy Kun 4955b6f170SJeremy Kun // A + indicates it's a constant term with more to go, as in `1 + x`. 5055b6f170SJeremy Kun if (succeeded(parser.parseOptionalPlus())) { 5155b6f170SJeremy Kun // If no coefficient was parsed, and there's a +, then it's effectively 5255b6f170SJeremy Kun // parsing an empty string. 5355b6f170SJeremy Kun if (!parsedCoeffResult.has_value()) { 5455b6f170SJeremy Kun return failure(); 5555b6f170SJeremy Kun } 562ff43ce8SJeremy Kun monomial.setExponent(APInt(apintBitWidth, 0)); 5755b6f170SJeremy Kun isConstantTerm = true; 5855b6f170SJeremy Kun shouldParseMore = true; 5955b6f170SJeremy Kun return success(); 6055b6f170SJeremy Kun } 6155b6f170SJeremy Kun 6255b6f170SJeremy Kun // A monomial can be a trailing constant term, as in `x + 1`. 6355b6f170SJeremy Kun if (failed(parser.parseOptionalKeyword(&variable))) { 6455b6f170SJeremy Kun // If neither a coefficient nor a variable was found, then it's effectively 6555b6f170SJeremy Kun // parsing an empty string. 6655b6f170SJeremy Kun if (!parsedCoeffResult.has_value()) { 6755b6f170SJeremy Kun return failure(); 6855b6f170SJeremy Kun } 6955b6f170SJeremy Kun 702ff43ce8SJeremy Kun monomial.setExponent(APInt(apintBitWidth, 0)); 7155b6f170SJeremy Kun isConstantTerm = true; 7255b6f170SJeremy Kun return success(); 7355b6f170SJeremy Kun } 7455b6f170SJeremy Kun 7555b6f170SJeremy Kun // Parse exponentiation symbol as `**`. We can't use caret because it's 7655b6f170SJeremy Kun // reserved for basic block identifiers If no star is present, it's treated 7755b6f170SJeremy Kun // as a polynomial with exponent 1. 7855b6f170SJeremy Kun if (succeeded(parser.parseOptionalStar())) { 7955b6f170SJeremy Kun // If there's one * there must be two. 8055b6f170SJeremy Kun if (failed(parser.parseStar())) { 8155b6f170SJeremy Kun return failure(); 8255b6f170SJeremy Kun } 8355b6f170SJeremy Kun 8455b6f170SJeremy Kun // If there's a **, then the integer exponent is required. 8555b6f170SJeremy Kun APInt parsedExponent(apintBitWidth, 0); 8655b6f170SJeremy Kun if (failed(parser.parseInteger(parsedExponent))) { 8755b6f170SJeremy Kun parser.emitError(parser.getCurrentLocation(), 8855b6f170SJeremy Kun "found invalid integer exponent"); 8955b6f170SJeremy Kun return failure(); 9055b6f170SJeremy Kun } 9155b6f170SJeremy Kun 922ff43ce8SJeremy Kun monomial.setExponent(parsedExponent); 9355b6f170SJeremy Kun } else { 942ff43ce8SJeremy Kun monomial.setExponent(APInt(apintBitWidth, 1)); 9555b6f170SJeremy Kun } 9655b6f170SJeremy Kun 9755b6f170SJeremy Kun if (succeeded(parser.parseOptionalPlus())) { 9855b6f170SJeremy Kun shouldParseMore = true; 9955b6f170SJeremy Kun } 10055b6f170SJeremy Kun return success(); 10155b6f170SJeremy Kun } 10255b6f170SJeremy Kun 103ab29203eSJeremy Kun template <typename Monomial> 1042ff43ce8SJeremy Kun LogicalResult 1052ff43ce8SJeremy Kun parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials, 1062ff43ce8SJeremy Kun llvm::StringSet<> &variables, 1072ff43ce8SJeremy Kun ParseCoefficientFn<Monomial> parseAndStoreCoefficient) { 10855b6f170SJeremy Kun while (true) { 10955b6f170SJeremy Kun Monomial parsedMonomial; 11055b6f170SJeremy Kun llvm::StringRef parsedVariableRef; 11155b6f170SJeremy Kun bool isConstantTerm; 11255b6f170SJeremy Kun bool shouldParseMore; 1132ff43ce8SJeremy Kun if (failed(parseMonomial<Monomial>( 1142ff43ce8SJeremy Kun parser, parsedMonomial, parsedVariableRef, isConstantTerm, 1152ff43ce8SJeremy Kun shouldParseMore, parseAndStoreCoefficient))) { 11655b6f170SJeremy Kun parser.emitError(parser.getCurrentLocation(), "expected a monomial"); 1172ff43ce8SJeremy Kun return failure(); 11855b6f170SJeremy Kun } 11955b6f170SJeremy Kun 12055b6f170SJeremy Kun if (!isConstantTerm) { 12155b6f170SJeremy Kun std::string parsedVariable = parsedVariableRef.str(); 12255b6f170SJeremy Kun variables.insert(parsedVariable); 12355b6f170SJeremy Kun } 12455b6f170SJeremy Kun monomials.push_back(parsedMonomial); 12555b6f170SJeremy Kun 12655b6f170SJeremy Kun if (shouldParseMore) 12755b6f170SJeremy Kun continue; 12855b6f170SJeremy Kun 12955b6f170SJeremy Kun if (succeeded(parser.parseOptionalGreater())) { 13055b6f170SJeremy Kun break; 13155b6f170SJeremy Kun } 13255b6f170SJeremy Kun parser.emitError( 13355b6f170SJeremy Kun parser.getCurrentLocation(), 13455b6f170SJeremy Kun "expected + and more monomials, or > to end polynomial attribute"); 1352ff43ce8SJeremy Kun return failure(); 13655b6f170SJeremy Kun } 13755b6f170SJeremy Kun 13855b6f170SJeremy Kun if (variables.size() > 1) { 13955b6f170SJeremy Kun std::string vars = llvm::join(variables.keys(), ", "); 14055b6f170SJeremy Kun parser.emitError( 14155b6f170SJeremy Kun parser.getCurrentLocation(), 14255b6f170SJeremy Kun "polynomials must have one indeterminate, but there were multiple: " + 14355b6f170SJeremy Kun vars); 1442ff43ce8SJeremy Kun return failure(); 14555b6f170SJeremy Kun } 14655b6f170SJeremy Kun 1472ff43ce8SJeremy Kun return success(); 1482ff43ce8SJeremy Kun } 1492ff43ce8SJeremy Kun 1502ff43ce8SJeremy Kun Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { 1512ff43ce8SJeremy Kun if (failed(parser.parseLess())) 1522ff43ce8SJeremy Kun return {}; 1532ff43ce8SJeremy Kun 1542ff43ce8SJeremy Kun llvm::SmallVector<IntMonomial> monomials; 1552ff43ce8SJeremy Kun llvm::StringSet<> variables; 1562ff43ce8SJeremy Kun 157ab29203eSJeremy Kun if (failed(parsePolynomialAttr<IntMonomial>( 1582ff43ce8SJeremy Kun parser, monomials, variables, 1592ff43ce8SJeremy Kun [&](IntMonomial &monomial) -> OptionalParseResult { 1602ff43ce8SJeremy Kun APInt parsedCoeff(apintBitWidth, 1); 1612ff43ce8SJeremy Kun OptionalParseResult result = 1622ff43ce8SJeremy Kun parser.parseOptionalInteger(parsedCoeff); 1632ff43ce8SJeremy Kun monomial.setCoefficient(parsedCoeff); 1642ff43ce8SJeremy Kun return result; 1652ff43ce8SJeremy Kun }))) { 1662ff43ce8SJeremy Kun return {}; 1672ff43ce8SJeremy Kun } 1682ff43ce8SJeremy Kun 1692ff43ce8SJeremy Kun auto result = IntPolynomial::fromMonomials(monomials); 17055b6f170SJeremy Kun if (failed(result)) { 17155b6f170SJeremy Kun parser.emitError(parser.getCurrentLocation()) 17255b6f170SJeremy Kun << "parsed polynomial must have unique exponents among monomials"; 17355b6f170SJeremy Kun return {}; 17455b6f170SJeremy Kun } 1752ff43ce8SJeremy Kun return IntPolynomialAttr::get(parser.getContext(), result.value()); 17655b6f170SJeremy Kun } 1772ff43ce8SJeremy Kun Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { 17855b6f170SJeremy Kun if (failed(parser.parseLess())) 17955b6f170SJeremy Kun return {}; 18055b6f170SJeremy Kun 1812ff43ce8SJeremy Kun llvm::SmallVector<FloatMonomial> monomials; 1822ff43ce8SJeremy Kun llvm::StringSet<> variables; 18355b6f170SJeremy Kun 1842ff43ce8SJeremy Kun ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient = 1852ff43ce8SJeremy Kun [&](FloatMonomial &monomial) -> OptionalParseResult { 1862ff43ce8SJeremy Kun double coeffValue = 1.0; 1872ff43ce8SJeremy Kun ParseResult result = parser.parseFloat(coeffValue); 1882ff43ce8SJeremy Kun monomial.setCoefficient(APFloat(coeffValue)); 1892ff43ce8SJeremy Kun return OptionalParseResult(result); 1902ff43ce8SJeremy Kun }; 19155b6f170SJeremy Kun 192ab29203eSJeremy Kun if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables, 193ab29203eSJeremy Kun parseAndStoreCoefficient))) { 19455b6f170SJeremy Kun return {}; 19555b6f170SJeremy Kun } 1962ff43ce8SJeremy Kun 1972ff43ce8SJeremy Kun auto result = FloatPolynomial::fromMonomials(monomials); 19855b6f170SJeremy Kun if (failed(result)) { 1992ff43ce8SJeremy Kun parser.emitError(parser.getCurrentLocation()) 2002ff43ce8SJeremy Kun << "parsed polynomial must have unique exponents among monomials"; 20155b6f170SJeremy Kun return {}; 20255b6f170SJeremy Kun } 2032ff43ce8SJeremy Kun return FloatPolynomialAttr::get(parser.getContext(), result.value()); 20455b6f170SJeremy Kun } 20555b6f170SJeremy Kun 206*4425dfbaSHongren Zheng LogicalResult 207*4425dfbaSHongren Zheng RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError, 208*4425dfbaSHongren Zheng Type coefficientType, IntegerAttr coefficientModulus, 209*4425dfbaSHongren Zheng IntPolynomialAttr polynomialModulus) { 210*4425dfbaSHongren Zheng if (coefficientModulus) { 211*4425dfbaSHongren Zheng auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType); 212*4425dfbaSHongren Zheng if (!coeffIntType) { 213*4425dfbaSHongren Zheng return emitError() << "coefficientModulus specified but coefficientType " 214*4425dfbaSHongren Zheng "is not integral"; 215*4425dfbaSHongren Zheng } 216*4425dfbaSHongren Zheng APInt coeffModValue = coefficientModulus.getValue(); 217*4425dfbaSHongren Zheng if (coeffModValue == 0) { 218*4425dfbaSHongren Zheng return emitError() << "coefficientModulus should not be 0"; 219*4425dfbaSHongren Zheng } 220*4425dfbaSHongren Zheng if (coeffModValue.slt(0)) { 221*4425dfbaSHongren Zheng return emitError() << "coefficientModulus should be positive"; 222*4425dfbaSHongren Zheng } 223*4425dfbaSHongren Zheng auto coeffModWidth = (coeffModValue - 1).getActiveBits(); 224*4425dfbaSHongren Zheng auto coeffWidth = coeffIntType.getWidth(); 225*4425dfbaSHongren Zheng if (coeffModWidth > coeffWidth) { 226*4425dfbaSHongren Zheng return emitError() << "coefficientModulus needs bit width of " 227*4425dfbaSHongren Zheng << coeffModWidth 228*4425dfbaSHongren Zheng << " but coefficientType can only contain " 229*4425dfbaSHongren Zheng << coeffWidth << " bits"; 230*4425dfbaSHongren Zheng } 231*4425dfbaSHongren Zheng } 232*4425dfbaSHongren Zheng return success(); 233*4425dfbaSHongren Zheng } 234*4425dfbaSHongren Zheng 23555b6f170SJeremy Kun } // namespace polynomial 23655b6f170SJeremy Kun } // namespace mlir 237