xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp (revision 624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94)
1 //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
9 
10 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/ADT/StringSet.h"
16 
17 namespace mlir {
18 namespace polynomial {
19 
20 void PolynomialAttr::print(AsmPrinter &p) const {
21   p << '<';
22   p << getPolynomial();
23   p << '>';
24 }
25 
26 /// Try to parse a monomial. If successful, populate the fields of the outparam
27 /// `monomial` with the results, and the `variable` outparam with the parsed
28 /// variable name. Sets shouldParseMore to true if the monomial is followed by
29 /// a '+'.
30 ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
31                           llvm::StringRef &variable, bool &isConstantTerm,
32                           bool &shouldParseMore) {
33   APInt parsedCoeff(apintBitWidth, 1);
34   auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
35   monomial.coefficient = parsedCoeff;
36 
37   isConstantTerm = false;
38   shouldParseMore = false;
39 
40   // A + indicates it's a constant term with more to go, as in `1 + x`.
41   if (succeeded(parser.parseOptionalPlus())) {
42     // If no coefficient was parsed, and there's a +, then it's effectively
43     // parsing an empty string.
44     if (!parsedCoeffResult.has_value()) {
45       return failure();
46     }
47     monomial.exponent = APInt(apintBitWidth, 0);
48     isConstantTerm = true;
49     shouldParseMore = true;
50     return success();
51   }
52 
53   // A monomial can be a trailing constant term, as in `x + 1`.
54   if (failed(parser.parseOptionalKeyword(&variable))) {
55     // If neither a coefficient nor a variable was found, then it's effectively
56     // parsing an empty string.
57     if (!parsedCoeffResult.has_value()) {
58       return failure();
59     }
60 
61     monomial.exponent = APInt(apintBitWidth, 0);
62     isConstantTerm = true;
63     return success();
64   }
65 
66   // Parse exponentiation symbol as `**`. We can't use caret because it's
67   // reserved for basic block identifiers If no star is present, it's treated
68   // as a polynomial with exponent 1.
69   if (succeeded(parser.parseOptionalStar())) {
70     // If there's one * there must be two.
71     if (failed(parser.parseStar())) {
72       return failure();
73     }
74 
75     // If there's a **, then the integer exponent is required.
76     APInt parsedExponent(apintBitWidth, 0);
77     if (failed(parser.parseInteger(parsedExponent))) {
78       parser.emitError(parser.getCurrentLocation(),
79                        "found invalid integer exponent");
80       return failure();
81     }
82 
83     monomial.exponent = parsedExponent;
84   } else {
85     monomial.exponent = APInt(apintBitWidth, 1);
86   }
87 
88   if (succeeded(parser.parseOptionalPlus())) {
89     shouldParseMore = true;
90   }
91   return success();
92 }
93 
94 Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
95   if (failed(parser.parseLess()))
96     return {};
97 
98   llvm::SmallVector<Monomial> monomials;
99   llvm::StringSet<> variables;
100 
101   while (true) {
102     Monomial parsedMonomial;
103     llvm::StringRef parsedVariableRef;
104     bool isConstantTerm;
105     bool shouldParseMore;
106     if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
107                              isConstantTerm, shouldParseMore))) {
108       parser.emitError(parser.getCurrentLocation(), "expected a monomial");
109       return {};
110     }
111 
112     if (!isConstantTerm) {
113       std::string parsedVariable = parsedVariableRef.str();
114       variables.insert(parsedVariable);
115     }
116     monomials.push_back(parsedMonomial);
117 
118     if (shouldParseMore)
119       continue;
120 
121     if (succeeded(parser.parseOptionalGreater())) {
122       break;
123     }
124     parser.emitError(
125         parser.getCurrentLocation(),
126         "expected + and more monomials, or > to end polynomial attribute");
127     return {};
128   }
129 
130   if (variables.size() > 1) {
131     std::string vars = llvm::join(variables.keys(), ", ");
132     parser.emitError(
133         parser.getCurrentLocation(),
134         "polynomials must have one indeterminate, but there were multiple: " +
135             vars);
136   }
137 
138   auto result = Polynomial::fromMonomials(monomials);
139   if (failed(result)) {
140     parser.emitError(parser.getCurrentLocation())
141         << "parsed polynomial must have unique exponents among monomials";
142     return {};
143   }
144   return PolynomialAttr::get(parser.getContext(), result.value());
145 }
146 
147 void RingAttr::print(AsmPrinter &p) const {
148   p << "#polynomial.ring<coefficientType=" << getCoefficientType()
149     << ", coefficientModulus=" << getCoefficientModulus()
150     << ", polynomialModulus=" << getPolynomialModulus() << '>';
151 }
152 
153 Attribute RingAttr::parse(AsmParser &parser, Type type) {
154   if (failed(parser.parseLess()))
155     return {};
156 
157   if (failed(parser.parseKeyword("coefficientType")))
158     return {};
159 
160   if (failed(parser.parseEqual()))
161     return {};
162 
163   Type ty;
164   if (failed(parser.parseType(ty)))
165     return {};
166 
167   if (failed(parser.parseComma()))
168     return {};
169 
170   IntegerAttr coefficientModulusAttr = nullptr;
171   if (succeeded(parser.parseKeyword("coefficientModulus"))) {
172     if (failed(parser.parseEqual()))
173       return {};
174 
175     IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
176     if (!iType) {
177       parser.emitError(parser.getCurrentLocation(),
178                        "coefficientType must specify an integer type");
179       return {};
180     }
181     APInt coefficientModulus(iType.getWidth(), 0);
182     auto result = parser.parseInteger(coefficientModulus);
183     if (failed(result)) {
184       parser.emitError(parser.getCurrentLocation(),
185                        "invalid coefficient modulus");
186       return {};
187     }
188     coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
189 
190     if (failed(parser.parseComma()))
191       return {};
192   }
193 
194   PolynomialAttr polyAttr = nullptr;
195   if (succeeded(parser.parseKeyword("polynomialModulus"))) {
196     if (failed(parser.parseEqual()))
197       return {};
198 
199     PolynomialAttr attr;
200     if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
201       return {};
202     polyAttr = attr;
203   }
204 
205   Polynomial poly = polyAttr.getPolynomial();
206   APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
207   IntegerAttr rootAttr = nullptr;
208   if (succeeded(parser.parseOptionalComma())) {
209     if (failed(parser.parseKeyword("primitiveRoot")) ||
210         failed(parser.parseEqual()))
211       return {};
212 
213     ParseResult result = parser.parseInteger(root);
214     if (failed(result)) {
215       parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
216       return {};
217     }
218     rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
219   }
220 
221   if (failed(parser.parseGreater()))
222     return {};
223 
224   return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
225                        polyAttr, rootAttr);
226 }
227 
228 } // namespace polynomial
229 } // namespace mlir
230