//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "mlir/Dialect/Polynomial/IR/Polynomial.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" namespace mlir { namespace polynomial { void IntPolynomialAttr::print(AsmPrinter &p) const { p << '<' << getPolynomial() << '>'; } void FloatPolynomialAttr::print(AsmPrinter &p) const { p << '<' << getPolynomial() << '>'; } /// A callable that parses the coefficient using the appropriate method for the /// given monomial type, and stores the parsed coefficient value on the /// monomial. template using ParseCoefficientFn = std::function; /// Try to parse a monomial. If successful, populate the fields of the outparam /// `monomial` with the results, and the `variable` outparam with the parsed /// variable name. Sets shouldParseMore to true if the monomial is followed by /// a '+'. /// template ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, bool &isConstantTerm, bool &shouldParseMore, ParseCoefficientFn parseAndStoreCoefficient) { OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial); isConstantTerm = false; shouldParseMore = false; // A + indicates it's a constant term with more to go, as in `1 + x`. if (succeeded(parser.parseOptionalPlus())) { // If no coefficient was parsed, and there's a +, then it's effectively // parsing an empty string. if (!parsedCoeffResult.has_value()) { return failure(); } monomial.setExponent(APInt(apintBitWidth, 0)); isConstantTerm = true; shouldParseMore = true; return success(); } // A monomial can be a trailing constant term, as in `x + 1`. if (failed(parser.parseOptionalKeyword(&variable))) { // If neither a coefficient nor a variable was found, then it's effectively // parsing an empty string. if (!parsedCoeffResult.has_value()) { return failure(); } monomial.setExponent(APInt(apintBitWidth, 0)); isConstantTerm = true; return success(); } // Parse exponentiation symbol as `**`. We can't use caret because it's // reserved for basic block identifiers If no star is present, it's treated // as a polynomial with exponent 1. if (succeeded(parser.parseOptionalStar())) { // If there's one * there must be two. if (failed(parser.parseStar())) { return failure(); } // If there's a **, then the integer exponent is required. APInt parsedExponent(apintBitWidth, 0); if (failed(parser.parseInteger(parsedExponent))) { parser.emitError(parser.getCurrentLocation(), "found invalid integer exponent"); return failure(); } monomial.setExponent(parsedExponent); } else { monomial.setExponent(APInt(apintBitWidth, 1)); } if (succeeded(parser.parseOptionalPlus())) { shouldParseMore = true; } return success(); } template LogicalResult parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, llvm::StringSet<> &variables, ParseCoefficientFn parseAndStoreCoefficient) { while (true) { Monomial parsedMonomial; llvm::StringRef parsedVariableRef; bool isConstantTerm; bool shouldParseMore; if (failed(parseMonomial( parser, parsedMonomial, parsedVariableRef, isConstantTerm, shouldParseMore, parseAndStoreCoefficient))) { parser.emitError(parser.getCurrentLocation(), "expected a monomial"); return failure(); } if (!isConstantTerm) { std::string parsedVariable = parsedVariableRef.str(); variables.insert(parsedVariable); } monomials.push_back(parsedMonomial); if (shouldParseMore) continue; if (succeeded(parser.parseOptionalGreater())) { break; } parser.emitError( parser.getCurrentLocation(), "expected + and more monomials, or > to end polynomial attribute"); return failure(); } if (variables.size() > 1) { std::string vars = llvm::join(variables.keys(), ", "); parser.emitError( parser.getCurrentLocation(), "polynomials must have one indeterminate, but there were multiple: " + vars); return failure(); } return success(); } Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; llvm::SmallVector monomials; llvm::StringSet<> variables; if (failed(parsePolynomialAttr( parser, monomials, variables, [&](IntMonomial &monomial) -> OptionalParseResult { APInt parsedCoeff(apintBitWidth, 1); OptionalParseResult result = parser.parseOptionalInteger(parsedCoeff); monomial.setCoefficient(parsedCoeff); return result; }))) { return {}; } auto result = IntPolynomial::fromMonomials(monomials); if (failed(result)) { parser.emitError(parser.getCurrentLocation()) << "parsed polynomial must have unique exponents among monomials"; return {}; } return IntPolynomialAttr::get(parser.getContext(), result.value()); } Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; llvm::SmallVector monomials; llvm::StringSet<> variables; ParseCoefficientFn parseAndStoreCoefficient = [&](FloatMonomial &monomial) -> OptionalParseResult { double coeffValue = 1.0; ParseResult result = parser.parseFloat(coeffValue); monomial.setCoefficient(APFloat(coeffValue)); return OptionalParseResult(result); }; if (failed(parsePolynomialAttr(parser, monomials, variables, parseAndStoreCoefficient))) { return {}; } auto result = FloatPolynomial::fromMonomials(monomials); if (failed(result)) { parser.emitError(parser.getCurrentLocation()) << "parsed polynomial must have unique exponents among monomials"; return {}; } return FloatPolynomialAttr::get(parser.getContext(), result.value()); } LogicalResult RingAttr::verify(function_ref emitError, Type coefficientType, IntegerAttr coefficientModulus, IntPolynomialAttr polynomialModulus) { if (coefficientModulus) { auto coeffIntType = llvm::dyn_cast(coefficientType); if (!coeffIntType) { return emitError() << "coefficientModulus specified but coefficientType " "is not integral"; } APInt coeffModValue = coefficientModulus.getValue(); if (coeffModValue == 0) { return emitError() << "coefficientModulus should not be 0"; } if (coeffModValue.slt(0)) { return emitError() << "coefficientModulus should be positive"; } auto coeffModWidth = (coeffModValue - 1).getActiveBits(); auto coeffWidth = coeffIntType.getWidth(); if (coeffModWidth > coeffWidth) { return emitError() << "coefficientModulus needs bit width of " << coeffModWidth << " but coefficientType can only contain " << coeffWidth << " bits"; } } return success(); } } // namespace polynomial } // namespace mlir