xref: /llvm-project/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp (revision 4425dfba6a1f394e958e94aa471a07bcf707136a)
1 //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
9 
10 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/ADT/StringSet.h"
16 
17 namespace mlir {
18 namespace polynomial {
19 
20 void IntPolynomialAttr::print(AsmPrinter &p) const {
21   p << '<' << getPolynomial() << '>';
22 }
23 
24 void FloatPolynomialAttr::print(AsmPrinter &p) const {
25   p << '<' << getPolynomial() << '>';
26 }
27 
28 /// A callable that parses the coefficient using the appropriate method for the
29 /// given monomial type, and stores the parsed coefficient value on the
30 /// monomial.
31 template <typename MonomialType>
32 using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
33 
34 /// Try to parse a monomial. If successful, populate the fields of the outparam
35 /// `monomial` with the results, and the `variable` outparam with the parsed
36 /// variable name. Sets shouldParseMore to true if the monomial is followed by
37 /// a '+'.
38 ///
39 template <typename Monomial>
40 ParseResult
41 parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
42               bool &isConstantTerm, bool &shouldParseMore,
43               ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
44   OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
45 
46   isConstantTerm = false;
47   shouldParseMore = false;
48 
49   // A + indicates it's a constant term with more to go, as in `1 + x`.
50   if (succeeded(parser.parseOptionalPlus())) {
51     // If no coefficient was parsed, and there's a +, then it's effectively
52     // parsing an empty string.
53     if (!parsedCoeffResult.has_value()) {
54       return failure();
55     }
56     monomial.setExponent(APInt(apintBitWidth, 0));
57     isConstantTerm = true;
58     shouldParseMore = true;
59     return success();
60   }
61 
62   // A monomial can be a trailing constant term, as in `x + 1`.
63   if (failed(parser.parseOptionalKeyword(&variable))) {
64     // If neither a coefficient nor a variable was found, then it's effectively
65     // parsing an empty string.
66     if (!parsedCoeffResult.has_value()) {
67       return failure();
68     }
69 
70     monomial.setExponent(APInt(apintBitWidth, 0));
71     isConstantTerm = true;
72     return success();
73   }
74 
75   // Parse exponentiation symbol as `**`. We can't use caret because it's
76   // reserved for basic block identifiers If no star is present, it's treated
77   // as a polynomial with exponent 1.
78   if (succeeded(parser.parseOptionalStar())) {
79     // If there's one * there must be two.
80     if (failed(parser.parseStar())) {
81       return failure();
82     }
83 
84     // If there's a **, then the integer exponent is required.
85     APInt parsedExponent(apintBitWidth, 0);
86     if (failed(parser.parseInteger(parsedExponent))) {
87       parser.emitError(parser.getCurrentLocation(),
88                        "found invalid integer exponent");
89       return failure();
90     }
91 
92     monomial.setExponent(parsedExponent);
93   } else {
94     monomial.setExponent(APInt(apintBitWidth, 1));
95   }
96 
97   if (succeeded(parser.parseOptionalPlus())) {
98     shouldParseMore = true;
99   }
100   return success();
101 }
102 
103 template <typename Monomial>
104 LogicalResult
105 parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
106                     llvm::StringSet<> &variables,
107                     ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
108   while (true) {
109     Monomial parsedMonomial;
110     llvm::StringRef parsedVariableRef;
111     bool isConstantTerm;
112     bool shouldParseMore;
113     if (failed(parseMonomial<Monomial>(
114             parser, parsedMonomial, parsedVariableRef, isConstantTerm,
115             shouldParseMore, parseAndStoreCoefficient))) {
116       parser.emitError(parser.getCurrentLocation(), "expected a monomial");
117       return failure();
118     }
119 
120     if (!isConstantTerm) {
121       std::string parsedVariable = parsedVariableRef.str();
122       variables.insert(parsedVariable);
123     }
124     monomials.push_back(parsedMonomial);
125 
126     if (shouldParseMore)
127       continue;
128 
129     if (succeeded(parser.parseOptionalGreater())) {
130       break;
131     }
132     parser.emitError(
133         parser.getCurrentLocation(),
134         "expected + and more monomials, or > to end polynomial attribute");
135     return failure();
136   }
137 
138   if (variables.size() > 1) {
139     std::string vars = llvm::join(variables.keys(), ", ");
140     parser.emitError(
141         parser.getCurrentLocation(),
142         "polynomials must have one indeterminate, but there were multiple: " +
143             vars);
144     return failure();
145   }
146 
147   return success();
148 }
149 
150 Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
151   if (failed(parser.parseLess()))
152     return {};
153 
154   llvm::SmallVector<IntMonomial> monomials;
155   llvm::StringSet<> variables;
156 
157   if (failed(parsePolynomialAttr<IntMonomial>(
158           parser, monomials, variables,
159           [&](IntMonomial &monomial) -> OptionalParseResult {
160             APInt parsedCoeff(apintBitWidth, 1);
161             OptionalParseResult result =
162                 parser.parseOptionalInteger(parsedCoeff);
163             monomial.setCoefficient(parsedCoeff);
164             return result;
165           }))) {
166     return {};
167   }
168 
169   auto result = IntPolynomial::fromMonomials(monomials);
170   if (failed(result)) {
171     parser.emitError(parser.getCurrentLocation())
172         << "parsed polynomial must have unique exponents among monomials";
173     return {};
174   }
175   return IntPolynomialAttr::get(parser.getContext(), result.value());
176 }
177 Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
178   if (failed(parser.parseLess()))
179     return {};
180 
181   llvm::SmallVector<FloatMonomial> monomials;
182   llvm::StringSet<> variables;
183 
184   ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient =
185       [&](FloatMonomial &monomial) -> OptionalParseResult {
186     double coeffValue = 1.0;
187     ParseResult result = parser.parseFloat(coeffValue);
188     monomial.setCoefficient(APFloat(coeffValue));
189     return OptionalParseResult(result);
190   };
191 
192   if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables,
193                                                 parseAndStoreCoefficient))) {
194     return {};
195   }
196 
197   auto result = FloatPolynomial::fromMonomials(monomials);
198   if (failed(result)) {
199     parser.emitError(parser.getCurrentLocation())
200         << "parsed polynomial must have unique exponents among monomials";
201     return {};
202   }
203   return FloatPolynomialAttr::get(parser.getContext(), result.value());
204 }
205 
206 LogicalResult
207 RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
208                  Type coefficientType, IntegerAttr coefficientModulus,
209                  IntPolynomialAttr polynomialModulus) {
210   if (coefficientModulus) {
211     auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType);
212     if (!coeffIntType) {
213       return emitError() << "coefficientModulus specified but coefficientType "
214                             "is not integral";
215     }
216     APInt coeffModValue = coefficientModulus.getValue();
217     if (coeffModValue == 0) {
218       return emitError() << "coefficientModulus should not be 0";
219     }
220     if (coeffModValue.slt(0)) {
221       return emitError() << "coefficientModulus should be positive";
222     }
223     auto coeffModWidth = (coeffModValue - 1).getActiveBits();
224     auto coeffWidth = coeffIntType.getWidth();
225     if (coeffModWidth > coeffWidth) {
226       return emitError() << "coefficientModulus needs bit width of "
227                          << coeffModWidth
228                          << " but coefficientType can only contain "
229                          << coeffWidth << " bits";
230     }
231   }
232   return success();
233 }
234 
235 } // namespace polynomial
236 } // namespace mlir
237