1 //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" 9 10 #include "mlir/Dialect/Polynomial/IR/Polynomial.h" 11 #include "mlir/Support/LLVM.h" 12 #include "mlir/Support/LogicalResult.h" 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringRef.h" 15 #include "llvm/ADT/StringSet.h" 16 17 namespace mlir { 18 namespace polynomial { 19 20 void PolynomialAttr::print(AsmPrinter &p) const { 21 p << '<'; 22 p << getPolynomial(); 23 p << '>'; 24 } 25 26 /// Try to parse a monomial. If successful, populate the fields of the outparam 27 /// `monomial` with the results, and the `variable` outparam with the parsed 28 /// variable name. Sets shouldParseMore to true if the monomial is followed by 29 /// a '+'. 30 ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, 31 llvm::StringRef &variable, bool &isConstantTerm, 32 bool &shouldParseMore) { 33 APInt parsedCoeff(apintBitWidth, 1); 34 auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff); 35 monomial.coefficient = parsedCoeff; 36 37 isConstantTerm = false; 38 shouldParseMore = false; 39 40 // A + indicates it's a constant term with more to go, as in `1 + x`. 41 if (succeeded(parser.parseOptionalPlus())) { 42 // If no coefficient was parsed, and there's a +, then it's effectively 43 // parsing an empty string. 44 if (!parsedCoeffResult.has_value()) { 45 return failure(); 46 } 47 monomial.exponent = APInt(apintBitWidth, 0); 48 isConstantTerm = true; 49 shouldParseMore = true; 50 return success(); 51 } 52 53 // A monomial can be a trailing constant term, as in `x + 1`. 54 if (failed(parser.parseOptionalKeyword(&variable))) { 55 // If neither a coefficient nor a variable was found, then it's effectively 56 // parsing an empty string. 57 if (!parsedCoeffResult.has_value()) { 58 return failure(); 59 } 60 61 monomial.exponent = APInt(apintBitWidth, 0); 62 isConstantTerm = true; 63 return success(); 64 } 65 66 // Parse exponentiation symbol as `**`. We can't use caret because it's 67 // reserved for basic block identifiers If no star is present, it's treated 68 // as a polynomial with exponent 1. 69 if (succeeded(parser.parseOptionalStar())) { 70 // If there's one * there must be two. 71 if (failed(parser.parseStar())) { 72 return failure(); 73 } 74 75 // If there's a **, then the integer exponent is required. 76 APInt parsedExponent(apintBitWidth, 0); 77 if (failed(parser.parseInteger(parsedExponent))) { 78 parser.emitError(parser.getCurrentLocation(), 79 "found invalid integer exponent"); 80 return failure(); 81 } 82 83 monomial.exponent = parsedExponent; 84 } else { 85 monomial.exponent = APInt(apintBitWidth, 1); 86 } 87 88 if (succeeded(parser.parseOptionalPlus())) { 89 shouldParseMore = true; 90 } 91 return success(); 92 } 93 94 Attribute PolynomialAttr::parse(AsmParser &parser, Type type) { 95 if (failed(parser.parseLess())) 96 return {}; 97 98 llvm::SmallVector<Monomial> monomials; 99 llvm::StringSet<> variables; 100 101 while (true) { 102 Monomial parsedMonomial; 103 llvm::StringRef parsedVariableRef; 104 bool isConstantTerm; 105 bool shouldParseMore; 106 if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef, 107 isConstantTerm, shouldParseMore))) { 108 parser.emitError(parser.getCurrentLocation(), "expected a monomial"); 109 return {}; 110 } 111 112 if (!isConstantTerm) { 113 std::string parsedVariable = parsedVariableRef.str(); 114 variables.insert(parsedVariable); 115 } 116 monomials.push_back(parsedMonomial); 117 118 if (shouldParseMore) 119 continue; 120 121 if (succeeded(parser.parseOptionalGreater())) { 122 break; 123 } 124 parser.emitError( 125 parser.getCurrentLocation(), 126 "expected + and more monomials, or > to end polynomial attribute"); 127 return {}; 128 } 129 130 if (variables.size() > 1) { 131 std::string vars = llvm::join(variables.keys(), ", "); 132 parser.emitError( 133 parser.getCurrentLocation(), 134 "polynomials must have one indeterminate, but there were multiple: " + 135 vars); 136 } 137 138 auto result = Polynomial::fromMonomials(monomials); 139 if (failed(result)) { 140 parser.emitError(parser.getCurrentLocation()) 141 << "parsed polynomial must have unique exponents among monomials"; 142 return {}; 143 } 144 return PolynomialAttr::get(parser.getContext(), result.value()); 145 } 146 147 void RingAttr::print(AsmPrinter &p) const { 148 p << "#polynomial.ring<coefficientType=" << getCoefficientType() 149 << ", coefficientModulus=" << getCoefficientModulus() 150 << ", polynomialModulus=" << getPolynomialModulus() << '>'; 151 } 152 153 Attribute RingAttr::parse(AsmParser &parser, Type type) { 154 if (failed(parser.parseLess())) 155 return {}; 156 157 if (failed(parser.parseKeyword("coefficientType"))) 158 return {}; 159 160 if (failed(parser.parseEqual())) 161 return {}; 162 163 Type ty; 164 if (failed(parser.parseType(ty))) 165 return {}; 166 167 if (failed(parser.parseComma())) 168 return {}; 169 170 IntegerAttr coefficientModulusAttr = nullptr; 171 if (succeeded(parser.parseKeyword("coefficientModulus"))) { 172 if (failed(parser.parseEqual())) 173 return {}; 174 175 IntegerType iType = mlir::dyn_cast<IntegerType>(ty); 176 if (!iType) { 177 parser.emitError(parser.getCurrentLocation(), 178 "coefficientType must specify an integer type"); 179 return {}; 180 } 181 APInt coefficientModulus(iType.getWidth(), 0); 182 auto result = parser.parseInteger(coefficientModulus); 183 if (failed(result)) { 184 parser.emitError(parser.getCurrentLocation(), 185 "invalid coefficient modulus"); 186 return {}; 187 } 188 coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus); 189 190 if (failed(parser.parseComma())) 191 return {}; 192 } 193 194 PolynomialAttr polyAttr = nullptr; 195 if (succeeded(parser.parseKeyword("polynomialModulus"))) { 196 if (failed(parser.parseEqual())) 197 return {}; 198 199 PolynomialAttr attr; 200 if (failed(parser.parseAttribute<PolynomialAttr>(attr))) 201 return {}; 202 polyAttr = attr; 203 } 204 205 Polynomial poly = polyAttr.getPolynomial(); 206 APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0); 207 IntegerAttr rootAttr = nullptr; 208 if (succeeded(parser.parseOptionalComma())) { 209 if (failed(parser.parseKeyword("primitiveRoot")) || 210 failed(parser.parseEqual())) 211 return {}; 212 213 ParseResult result = parser.parseInteger(root); 214 if (failed(result)) { 215 parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot"); 216 return {}; 217 } 218 rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root); 219 } 220 221 if (failed(parser.parseGreater())) 222 return {}; 223 224 return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr, 225 polyAttr, rootAttr); 226 } 227 228 } // namespace polynomial 229 } // namespace mlir 230