//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "mlir/Dialect/Polynomial/IR/Polynomial.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" namespace mlir { namespace polynomial { void PolynomialAttr::print(AsmPrinter &p) const { p << '<'; p << getPolynomial(); p << '>'; } /// Try to parse a monomial. If successful, populate the fields of the outparam /// `monomial` with the results, and the `variable` outparam with the parsed /// variable name. Sets shouldParseMore to true if the monomial is followed by /// a '+'. ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, bool &isConstantTerm, bool &shouldParseMore) { APInt parsedCoeff(apintBitWidth, 1); auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff); monomial.coefficient = parsedCoeff; isConstantTerm = false; shouldParseMore = false; // A + indicates it's a constant term with more to go, as in `1 + x`. if (succeeded(parser.parseOptionalPlus())) { // If no coefficient was parsed, and there's a +, then it's effectively // parsing an empty string. if (!parsedCoeffResult.has_value()) { return failure(); } monomial.exponent = APInt(apintBitWidth, 0); isConstantTerm = true; shouldParseMore = true; return success(); } // A monomial can be a trailing constant term, as in `x + 1`. if (failed(parser.parseOptionalKeyword(&variable))) { // If neither a coefficient nor a variable was found, then it's effectively // parsing an empty string. if (!parsedCoeffResult.has_value()) { return failure(); } monomial.exponent = APInt(apintBitWidth, 0); isConstantTerm = true; return success(); } // Parse exponentiation symbol as `**`. We can't use caret because it's // reserved for basic block identifiers If no star is present, it's treated // as a polynomial with exponent 1. if (succeeded(parser.parseOptionalStar())) { // If there's one * there must be two. if (failed(parser.parseStar())) { return failure(); } // If there's a **, then the integer exponent is required. APInt parsedExponent(apintBitWidth, 0); if (failed(parser.parseInteger(parsedExponent))) { parser.emitError(parser.getCurrentLocation(), "found invalid integer exponent"); return failure(); } monomial.exponent = parsedExponent; } else { monomial.exponent = APInt(apintBitWidth, 1); } if (succeeded(parser.parseOptionalPlus())) { shouldParseMore = true; } return success(); } Attribute PolynomialAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; llvm::SmallVector monomials; llvm::StringSet<> variables; while (true) { Monomial parsedMonomial; llvm::StringRef parsedVariableRef; bool isConstantTerm; bool shouldParseMore; if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef, isConstantTerm, shouldParseMore))) { parser.emitError(parser.getCurrentLocation(), "expected a monomial"); return {}; } if (!isConstantTerm) { std::string parsedVariable = parsedVariableRef.str(); variables.insert(parsedVariable); } monomials.push_back(parsedMonomial); if (shouldParseMore) continue; if (succeeded(parser.parseOptionalGreater())) { break; } parser.emitError( parser.getCurrentLocation(), "expected + and more monomials, or > to end polynomial attribute"); return {}; } if (variables.size() > 1) { std::string vars = llvm::join(variables.keys(), ", "); parser.emitError( parser.getCurrentLocation(), "polynomials must have one indeterminate, but there were multiple: " + vars); } auto result = Polynomial::fromMonomials(monomials); if (failed(result)) { parser.emitError(parser.getCurrentLocation()) << "parsed polynomial must have unique exponents among monomials"; return {}; } return PolynomialAttr::get(parser.getContext(), result.value()); } void RingAttr::print(AsmPrinter &p) const { p << "#polynomial.ring(ty); if (!iType) { parser.emitError(parser.getCurrentLocation(), "coefficientType must specify an integer type"); return {}; } APInt coefficientModulus(iType.getWidth(), 0); auto result = parser.parseInteger(coefficientModulus); if (failed(result)) { parser.emitError(parser.getCurrentLocation(), "invalid coefficient modulus"); return {}; } coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus); if (failed(parser.parseComma())) return {}; } PolynomialAttr polyAttr = nullptr; if (succeeded(parser.parseKeyword("polynomialModulus"))) { if (failed(parser.parseEqual())) return {}; PolynomialAttr attr; if (failed(parser.parseAttribute(attr))) return {}; polyAttr = attr; } if (failed(parser.parseGreater())) return {}; return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr, polyAttr); } } // namespace polynomial } // namespace mlir