xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp (revision 1944c4f76b47c0b86c91845987baca24fd4775f8)
1 //===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DimLvlMapParser.h"
10 
11 using namespace mlir;
12 using namespace mlir::sparse_tensor;
13 using namespace mlir::sparse_tensor::ir_detail;
14 
/// Returns `failure()` from the enclosing function when `RES` (anything
/// accepted by `failed(...)`) denotes a failure. Used as a bare statement,
/// without a trailing semicolon.
#define FAILURE_IF_FAILED(RES)                                                 \
  if (failed(RES))                                                             \
    return failure();
19 
20 /// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
21 /// its `RES` parameter.
didntSucceed(OptionalParseResult res)22 static inline bool didntSucceed(OptionalParseResult res) {
23   return !res.has_value() || failed(*res);
24 }
25 
/// Returns `failure()` from the enclosing function when `RES` (an
/// `OptionalParseResult`) is empty or holds a failure. Used as a bare
/// statement, without a trailing semicolon.
#define FAILURE_IF_NULLOPT_OR_FAILED(RES)                                      \
  if (didntSucceed(RES))                                                       \
    return failure();
30 
// When `COND` holds, emits `MSG` as an error at `loc` and returns the
// resulting diagnostic from the enclosing function.
// NOTE: the expansion refers to `parser` (an `AsmParser`) and `loc` (an
// `SMLoc`); both must be in scope wherever this macro is used.
#define ERROR_IF(COND, MSG)                                                    \
  if (COND)                                                                    \
    return parser.emitError(loc, MSG);
36 
37 //===----------------------------------------------------------------------===//
38 // `DimLvlMapParser` implementation for variable parsing.
39 //===----------------------------------------------------------------------===//
40 
41 // Our variation on `AffineParser::{parseBareIdExpr,parseIdentifierDefinition}`
parseVar(VarKind vk,bool isOptional,Policy creationPolicy,VarInfo::ID & varID,bool & didCreate)42 OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
43                                               Policy creationPolicy,
44                                               VarInfo::ID &varID,
45                                               bool &didCreate) {
46   // Save the current location so that we can have error messages point to
47   // the right place.
48   const auto loc = parser.getCurrentLocation();
49   StringRef name;
50   if (failed(parser.parseOptionalKeyword(&name))) {
51     ERROR_IF(!isOptional, "expected bare identifier")
52     return std::nullopt;
53   }
54 
55   if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
56     varID = res->first;
57     didCreate = res->second;
58     return success();
59   }
60 
61   switch (creationPolicy) {
62   case Policy::MustNot:
63     return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
64   case Policy::May:
65     llvm_unreachable("got nullopt for Policy::May");
66   case Policy::Must:
67     return parser.emitError(loc, "redefinition of identifier '" + name + "'");
68   }
69   llvm_unreachable("unknown Policy");
70 }
71 
parseVarUsage(VarKind vk,bool requireKnown)72 FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
73                                                       bool requireKnown) {
74   VarInfo::ID id;
75   bool didCreate;
76   const bool isOptional = false;
77   const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
78   const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
79   FAILURE_IF_NULLOPT_OR_FAILED(res)
80   assert(requireKnown ? !didCreate : true);
81   return id;
82 }
83 
parseVarBinding(VarKind vk,bool requireKnown)84 FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
85                                                         bool requireKnown) {
86   const auto loc = parser.getCurrentLocation();
87   VarInfo::ID id;
88   bool didCreate;
89   const bool isOptional = false;
90   const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
91   const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
92   FAILURE_IF_NULLOPT_OR_FAILED(res)
93   assert(requireKnown ? !didCreate : didCreate);
94   bindVar(loc, id);
95   return id;
96 }
97 
98 FailureOr<std::pair<Var, bool>>
parseOptionalVarBinding(VarKind vk,bool requireKnown)99 DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
100   const auto loc = parser.getCurrentLocation();
101   VarInfo::ID id;
102   bool didCreate;
103   const bool isOptional = true;
104   const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
105   const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
106   if (res.has_value()) {
107     FAILURE_IF_FAILED(*res)
108     assert(didCreate);
109     return std::make_pair(bindVar(loc, id), true);
110   }
111   assert(!didCreate);
112   return std::make_pair(env.bindUnusedVar(vk), false);
113 }
114 
bindVar(llvm::SMLoc loc,VarInfo::ID id)115 Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
116   MLIRContext *context = parser.getContext();
117   const auto var = env.bindVar(id);
118   const auto &info = std::as_const(env).access(id);
119   const auto name = info.getName();
120   const auto num = *info.getNum();
121   switch (info.getKind()) {
122   case VarKind::Symbol: {
123     const auto affine = getAffineSymbolExpr(num, context);
124     dimsAndSymbols.emplace_back(name, affine);
125     lvlsAndSymbols.emplace_back(name, affine);
126     return var;
127   }
128   case VarKind::Dimension:
129     dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
130     return var;
131   case VarKind::Level:
132     lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
133     return var;
134   }
135   llvm_unreachable("unknown VarKind");
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // `DimLvlMapParser` implementation for `DimLvlMap` per se.
140 //===----------------------------------------------------------------------===//
141 
parseDimLvlMap()142 FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
143   FAILURE_IF_FAILED(parseSymbolBindingList())
144   FAILURE_IF_FAILED(parseLvlVarBindingList())
145   FAILURE_IF_FAILED(parseDimSpecList())
146   FAILURE_IF_FAILED(parser.parseArrow())
147   FAILURE_IF_FAILED(parseLvlSpecList())
148   InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
149   if (failed(ifd))
150     return ifd;
151   return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
152 }
153 
parseSymbolBindingList()154 ParseResult DimLvlMapParser::parseSymbolBindingList() {
155   return parser.parseCommaSeparatedList(
156       OpAsmParser::Delimiter::OptionalSquare,
157       [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
158       " in symbol binding list");
159 }
160 
parseLvlVarBindingList()161 ParseResult DimLvlMapParser::parseLvlVarBindingList() {
162   return parser.parseCommaSeparatedList(
163       OpAsmParser::Delimiter::OptionalBraces,
164       [this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
165       " in level declaration list");
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // `DimLvlMapParser` implementation for `DimSpec`.
170 //===----------------------------------------------------------------------===//
171 
parseDimSpecList()172 ParseResult DimLvlMapParser::parseDimSpecList() {
173   return parser.parseCommaSeparatedList(
174       OpAsmParser::Delimiter::Paren,
175       [this]() -> ParseResult { return parseDimSpec(); },
176       " in dimension-specifier list");
177 }
178 
parseDimSpec()179 ParseResult DimLvlMapParser::parseDimSpec() {
180   // Parse the requisite dim-var binding.
181   const auto varID = parseVarBinding(VarKind::Dimension);
182   FAILURE_IF_FAILED(varID)
183   const DimVar var = env.getVar(*varID).cast<DimVar>();
184 
185   // Parse an optional dimension expression.
186   AffineExpr affine;
187   if (succeeded(parser.parseOptionalEqual())) {
188     // Parse the dim affine expr, with only any lvl-vars in scope.
189     FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
190   }
191   DimExpr expr{affine};
192 
193   // Parse an optional slice.
194   SparseTensorDimSliceAttr slice;
195   if (succeeded(parser.parseOptionalColon())) {
196     const auto loc = parser.getCurrentLocation();
197     Attribute attr;
198     FAILURE_IF_FAILED(parser.parseAttribute(attr))
199     slice = llvm::dyn_cast<SparseTensorDimSliceAttr>(attr);
200     ERROR_IF(!slice, "expected SparseTensorDimSliceAttr")
201   }
202 
203   dimSpecs.emplace_back(var, expr, slice);
204   return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // `DimLvlMapParser` implementation for `LvlSpec`.
209 //===----------------------------------------------------------------------===//
210 
parseLvlSpecList()211 ParseResult DimLvlMapParser::parseLvlSpecList() {
212   // This method currently only supports two syntaxes:
213   //
214   // (1) There are no forward-declarations, and no lvl-var bindings:
215   //        (d0, d1) -> (d0 : dense, d1 : compressed)
216   // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
217   // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
218   // the level-rank is correct at the end of parsing.
219   //
220   // (2) There are forward-declarations, and every lvl-spec must have
221   // a lvl-var binding:
222   //    {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
223   // However, this introduces duplicate information since the order of
224   // the lvl-vars in `parseLvlVarBindingList` must agree with their order
225   // in the list of lvl-specs.  Therefore, `parseLvlSpec` will not call
226   // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
227   // and must also validate the consistency between the two lvl-var orders.
228   const auto declaredLvlRank = env.getRanks().getLvlRank();
229   const bool requireLvlVarBinding = declaredLvlRank != 0;
230   // Have `ERROR_IF` point to the start of the list.
231   const auto loc = parser.getCurrentLocation();
232   const auto res = parser.parseCommaSeparatedList(
233       mlir::OpAsmParser::Delimiter::Paren,
234       [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
235       " in level-specifier list");
236   FAILURE_IF_FAILED(res)
237   const auto specLvlRank = lvlSpecs.size();
238   ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
239            "Level-rank mismatch between forward-declarations and specifiers. "
240            "Declared " +
241                Twine(declaredLvlRank) + " level-variables; but got " +
242                Twine(specLvlRank) + " level-specifiers.")
243   return success();
244 }
245 
nth(Var::Num n)246 static inline Twine nth(Var::Num n) {
247   switch (n) {
248   case 1:
249     return "1st";
250   case 2:
251     return "2nd";
252   default:
253     return Twine(n) + "th";
254   }
255 }
256 
257 FailureOr<LvlVar>
parseLvlVarBinding(bool requireLvlVarBinding)258 DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
259   // Nothing to parse, just bind an unnamed variable.
260   if (!requireLvlVarBinding)
261     return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
262 
263   const auto loc = parser.getCurrentLocation();
264   // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
265   // since the thing we're parsing is supposed to be a variable *binding*
266   // rather than a variable *use*.  However, the call to `VarEnv::bindVar`
267   // (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
268   // already occured in `parseLvlVarBindingList`, and therefore we must
269   // use `parseVarUsage` here in order to operationally do the right thing.
270   const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
271   FAILURE_IF_FAILED(varID)
272   const auto &info = std::as_const(env).access(*varID);
273   const auto var = info.getVar().cast<LvlVar>();
274   const auto forwardNum = var.getNum();
275   const auto specNum = lvlSpecs.size();
276   ERROR_IF(forwardNum != specNum,
277            "Level-variable ordering mismatch. The variable '" + info.getName() +
278                "' was forward-declared as the " + nth(forwardNum) +
279                " level; but is bound by the " + nth(specNum) +
280                " specification.")
281   FAILURE_IF_FAILED(parser.parseEqual())
282   return var;
283 }
284 
parseLvlSpec(bool requireLvlVarBinding)285 ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
286   // Parse the optional lvl-var binding. `requireLvlVarBinding`
287   // specifies whether that "optional" is actually Must or MustNot.
288   const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
289   FAILURE_IF_FAILED(varRes)
290   const LvlVar var = *varRes;
291 
292   // Parse the lvl affine expr, with only the dim-vars in scope.
293   AffineExpr affine;
294   FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
295   LvlExpr expr{affine};
296 
297   FAILURE_IF_FAILED(parser.parseColon())
298   const auto type = lvlTypeParser.parseLvlType(parser);
299   FAILURE_IF_FAILED(type)
300 
301   lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
302   return success();
303 }
304 
305 //===----------------------------------------------------------------------===//
306