1 //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" 9 10 #include "mlir/Dialect/Polynomial/IR/Polynomial.h" 11 #include "mlir/Support/LLVM.h" 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringRef.h" 15 #include "llvm/ADT/StringSet.h" 16 17 namespace mlir { 18 namespace polynomial { 19 20 void IntPolynomialAttr::print(AsmPrinter &p) const { 21 p << '<' << getPolynomial() << '>'; 22 } 23 24 void FloatPolynomialAttr::print(AsmPrinter &p) const { 25 p << '<' << getPolynomial() << '>'; 26 } 27 28 /// A callable that parses the coefficient using the appropriate method for the 29 /// given monomial type, and stores the parsed coefficient value on the 30 /// monomial. 31 template <typename MonomialType> 32 using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>; 33 34 /// Try to parse a monomial. If successful, populate the fields of the outparam 35 /// `monomial` with the results, and the `variable` outparam with the parsed 36 /// variable name. Sets shouldParseMore to true if the monomial is followed by 37 /// a '+'. 38 /// 39 template <typename Monomial> 40 ParseResult 41 parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, 42 bool &isConstantTerm, bool &shouldParseMore, 43 ParseCoefficientFn<Monomial> parseAndStoreCoefficient) { 44 OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial); 45 46 isConstantTerm = false; 47 shouldParseMore = false; 48 49 // A + indicates it's a constant term with more to go, as in `1 + x`. 50 if (succeeded(parser.parseOptionalPlus())) { 51 // If no coefficient was parsed, and there's a +, then it's effectively 52 // parsing an empty string. 53 if (!parsedCoeffResult.has_value()) { 54 return failure(); 55 } 56 monomial.setExponent(APInt(apintBitWidth, 0)); 57 isConstantTerm = true; 58 shouldParseMore = true; 59 return success(); 60 } 61 62 // A monomial can be a trailing constant term, as in `x + 1`. 63 if (failed(parser.parseOptionalKeyword(&variable))) { 64 // If neither a coefficient nor a variable was found, then it's effectively 65 // parsing an empty string. 66 if (!parsedCoeffResult.has_value()) { 67 return failure(); 68 } 69 70 monomial.setExponent(APInt(apintBitWidth, 0)); 71 isConstantTerm = true; 72 return success(); 73 } 74 75 // Parse exponentiation symbol as `**`. We can't use caret because it's 76 // reserved for basic block identifiers If no star is present, it's treated 77 // as a polynomial with exponent 1. 78 if (succeeded(parser.parseOptionalStar())) { 79 // If there's one * there must be two. 80 if (failed(parser.parseStar())) { 81 return failure(); 82 } 83 84 // If there's a **, then the integer exponent is required. 85 APInt parsedExponent(apintBitWidth, 0); 86 if (failed(parser.parseInteger(parsedExponent))) { 87 parser.emitError(parser.getCurrentLocation(), 88 "found invalid integer exponent"); 89 return failure(); 90 } 91 92 monomial.setExponent(parsedExponent); 93 } else { 94 monomial.setExponent(APInt(apintBitWidth, 1)); 95 } 96 97 if (succeeded(parser.parseOptionalPlus())) { 98 shouldParseMore = true; 99 } 100 return success(); 101 } 102 103 template <typename Monomial> 104 LogicalResult 105 parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials, 106 llvm::StringSet<> &variables, 107 ParseCoefficientFn<Monomial> parseAndStoreCoefficient) { 108 while (true) { 109 Monomial parsedMonomial; 110 llvm::StringRef parsedVariableRef; 111 bool isConstantTerm; 112 bool shouldParseMore; 113 if (failed(parseMonomial<Monomial>( 114 parser, parsedMonomial, parsedVariableRef, isConstantTerm, 115 shouldParseMore, parseAndStoreCoefficient))) { 116 parser.emitError(parser.getCurrentLocation(), "expected a monomial"); 117 return failure(); 118 } 119 120 if (!isConstantTerm) { 121 std::string parsedVariable = parsedVariableRef.str(); 122 variables.insert(parsedVariable); 123 } 124 monomials.push_back(parsedMonomial); 125 126 if (shouldParseMore) 127 continue; 128 129 if (succeeded(parser.parseOptionalGreater())) { 130 break; 131 } 132 parser.emitError( 133 parser.getCurrentLocation(), 134 "expected + and more monomials, or > to end polynomial attribute"); 135 return failure(); 136 } 137 138 if (variables.size() > 1) { 139 std::string vars = llvm::join(variables.keys(), ", "); 140 parser.emitError( 141 parser.getCurrentLocation(), 142 "polynomials must have one indeterminate, but there were multiple: " + 143 vars); 144 return failure(); 145 } 146 147 return success(); 148 } 149 150 Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { 151 if (failed(parser.parseLess())) 152 return {}; 153 154 llvm::SmallVector<IntMonomial> monomials; 155 llvm::StringSet<> variables; 156 157 if (failed(parsePolynomialAttr<IntMonomial>( 158 parser, monomials, variables, 159 [&](IntMonomial &monomial) -> OptionalParseResult { 160 APInt parsedCoeff(apintBitWidth, 1); 161 OptionalParseResult result = 162 parser.parseOptionalInteger(parsedCoeff); 163 monomial.setCoefficient(parsedCoeff); 164 return result; 165 }))) { 166 return {}; 167 } 168 169 auto result = IntPolynomial::fromMonomials(monomials); 170 if (failed(result)) { 171 parser.emitError(parser.getCurrentLocation()) 172 << "parsed polynomial must have unique exponents among monomials"; 173 return {}; 174 } 175 return IntPolynomialAttr::get(parser.getContext(), result.value()); 176 } 177 Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { 178 if (failed(parser.parseLess())) 179 return {}; 180 181 llvm::SmallVector<FloatMonomial> monomials; 182 llvm::StringSet<> variables; 183 184 ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient = 185 [&](FloatMonomial &monomial) -> OptionalParseResult { 186 double coeffValue = 1.0; 187 ParseResult result = parser.parseFloat(coeffValue); 188 monomial.setCoefficient(APFloat(coeffValue)); 189 return OptionalParseResult(result); 190 }; 191 192 if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables, 193 parseAndStoreCoefficient))) { 194 return {}; 195 } 196 197 auto result = FloatPolynomial::fromMonomials(monomials); 198 if (failed(result)) { 199 parser.emitError(parser.getCurrentLocation()) 200 << "parsed polynomial must have unique exponents among monomials"; 201 return {}; 202 } 203 return FloatPolynomialAttr::get(parser.getContext(), result.value()); 204 } 205 206 LogicalResult 207 RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError, 208 Type coefficientType, IntegerAttr coefficientModulus, 209 IntPolynomialAttr polynomialModulus) { 210 if (coefficientModulus) { 211 auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType); 212 if (!coeffIntType) { 213 return emitError() << "coefficientModulus specified but coefficientType " 214 "is not integral"; 215 } 216 APInt coeffModValue = coefficientModulus.getValue(); 217 if (coeffModValue == 0) { 218 return emitError() << "coefficientModulus should not be 0"; 219 } 220 if (coeffModValue.slt(0)) { 221 return emitError() << "coefficientModulus should be positive"; 222 } 223 auto coeffModWidth = (coeffModValue - 1).getActiveBits(); 224 auto coeffWidth = coeffIntType.getWidth(); 225 if (coeffModWidth > coeffWidth) { 226 return emitError() << "coefficientModulus needs bit width of " 227 << coeffModWidth 228 << " but coefficientType can only contain " 229 << coeffWidth << " bits"; 230 } 231 } 232 return success(); 233 } 234 235 } // namespace polynomial 236 } // namespace mlir 237