xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp (revision 1944c4f76b47c0b86c91845987baca24fd4775f8)
//===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
86b88c852SAart Bik 
96b88c852SAart Bik #include "DimLvlMapParser.h"
106b88c852SAart Bik 
116b88c852SAart Bik using namespace mlir;
126b88c852SAart Bik using namespace mlir::sparse_tensor;
136b88c852SAart Bik using namespace mlir::sparse_tensor::ir_detail;
146b88c852SAart Bik 
/// Early-exit helper: if `RES` is a failed result, the enclosing function
/// returns `failure()`.  The expansion is a complete `if` statement, so
/// call sites are written without a trailing semicolon.
#define FAILURE_IF_FAILED(RES)                                                 \
  if (failed(RES))                                                             \
    return failure();
19889f4bf2Swren romano 
20889f4bf2Swren romano /// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
21889f4bf2Swren romano /// its `RES` parameter.
didntSucceed(OptionalParseResult res)22889f4bf2Swren romano static inline bool didntSucceed(OptionalParseResult res) {
23889f4bf2Swren romano   return !res.has_value() || failed(*res);
24889f4bf2Swren romano }
25889f4bf2Swren romano 
/// Early-exit helper for optional results: the enclosing function returns
/// `failure()` whenever `didntSucceed(RES)` holds.  The expansion is a
/// complete `if` statement, so call sites omit the trailing semicolon.
#define FAILURE_IF_NULLOPT_OR_FAILED(RES)                                      \
  if (didntSucceed(RES))                                                       \
    return failure();
306b88c852SAart Bik 
/// Emits the diagnostic `MSG` at `loc` and returns it from the enclosing
/// function whenever `COND` holds.
// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
#define ERROR_IF(COND, MSG)                                                    \
  if (COND)                                                                    \
    return parser.emitError(loc, MSG);
366b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `DimLvlMapParser` implementation for variable parsing.
//===----------------------------------------------------------------------===//
406b88c852SAart Bik 
416b88c852SAart Bik // Our variation on `AffineParser::{parseBareIdExpr,parseIdentifierDefinition}`
parseVar(VarKind vk,bool isOptional,Policy creationPolicy,VarInfo::ID & varID,bool & didCreate)426b88c852SAart Bik OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
43ad7a6b67Swren romano                                               Policy creationPolicy,
446b88c852SAart Bik                                               VarInfo::ID &varID,
456b88c852SAart Bik                                               bool &didCreate) {
466b88c852SAart Bik   // Save the current location so that we can have error messages point to
4734ed07e6SYinying Li   // the right place.
486b88c852SAart Bik   const auto loc = parser.getCurrentLocation();
496b88c852SAart Bik   StringRef name;
506b88c852SAart Bik   if (failed(parser.parseOptionalKeyword(&name))) {
516b88c852SAart Bik     ERROR_IF(!isOptional, "expected bare identifier")
526b88c852SAart Bik     return std::nullopt;
536b88c852SAart Bik   }
546b88c852SAart Bik 
556b88c852SAart Bik   if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
566b88c852SAart Bik     varID = res->first;
576b88c852SAart Bik     didCreate = res->second;
586b88c852SAart Bik     return success();
596b88c852SAart Bik   }
6034ed07e6SYinying Li 
616b88c852SAart Bik   switch (creationPolicy) {
62ad7a6b67Swren romano   case Policy::MustNot:
636b88c852SAart Bik     return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
64ad7a6b67Swren romano   case Policy::May:
65ad7a6b67Swren romano     llvm_unreachable("got nullopt for Policy::May");
66ad7a6b67Swren romano   case Policy::Must:
676b88c852SAart Bik     return parser.emitError(loc, "redefinition of identifier '" + name + "'");
686b88c852SAart Bik   }
69ad7a6b67Swren romano   llvm_unreachable("unknown Policy");
706b88c852SAart Bik }
716b88c852SAart Bik 
parseVarUsage(VarKind vk,bool requireKnown)72889f4bf2Swren romano FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
73889f4bf2Swren romano                                                       bool requireKnown) {
74889f4bf2Swren romano   VarInfo::ID id;
756b88c852SAart Bik   bool didCreate;
76889f4bf2Swren romano   const bool isOptional = false;
77889f4bf2Swren romano   const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
78889f4bf2Swren romano   const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
79889f4bf2Swren romano   FAILURE_IF_NULLOPT_OR_FAILED(res)
80889f4bf2Swren romano   assert(requireKnown ? !didCreate : true);
81889f4bf2Swren romano   return id;
82889f4bf2Swren romano }
83889f4bf2Swren romano 
parseVarBinding(VarKind vk,bool requireKnown)84889f4bf2Swren romano FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
85889f4bf2Swren romano                                                         bool requireKnown) {
86889f4bf2Swren romano   const auto loc = parser.getCurrentLocation();
87889f4bf2Swren romano   VarInfo::ID id;
88889f4bf2Swren romano   bool didCreate;
89889f4bf2Swren romano   const bool isOptional = false;
90889f4bf2Swren romano   const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
91889f4bf2Swren romano   const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
92889f4bf2Swren romano   FAILURE_IF_NULLOPT_OR_FAILED(res)
93889f4bf2Swren romano   assert(requireKnown ? !didCreate : didCreate);
94889f4bf2Swren romano   bindVar(loc, id);
95889f4bf2Swren romano   return id;
966b88c852SAart Bik }
976b88c852SAart Bik 
986b88c852SAart Bik FailureOr<std::pair<Var, bool>>
parseOptionalVarBinding(VarKind vk,bool requireKnown)99889f4bf2Swren romano DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
100889f4bf2Swren romano   const auto loc = parser.getCurrentLocation();
1016b88c852SAart Bik   VarInfo::ID id;
1026b88c852SAart Bik   bool didCreate;
103889f4bf2Swren romano   const bool isOptional = true;
104889f4bf2Swren romano   const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
105889f4bf2Swren romano   const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
1066b88c852SAart Bik   if (res.has_value()) {
1076b88c852SAart Bik     FAILURE_IF_FAILED(*res)
108889f4bf2Swren romano     assert(didCreate);
109889f4bf2Swren romano     return std::make_pair(bindVar(loc, id), true);
110b939c015SAart Bik   }
111889f4bf2Swren romano   assert(!didCreate);
1126b88c852SAart Bik   return std::make_pair(env.bindUnusedVar(vk), false);
1136b88c852SAart Bik }
114b939c015SAart Bik 
bindVar(llvm::SMLoc loc,VarInfo::ID id)115889f4bf2Swren romano Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
116889f4bf2Swren romano   MLIRContext *context = parser.getContext();
117889f4bf2Swren romano   const auto var = env.bindVar(id);
118889f4bf2Swren romano   const auto &info = std::as_const(env).access(id);
119889f4bf2Swren romano   const auto name = info.getName();
120889f4bf2Swren romano   const auto num = *info.getNum();
121889f4bf2Swren romano   switch (info.getKind()) {
122889f4bf2Swren romano   case VarKind::Symbol: {
123889f4bf2Swren romano     const auto affine = getAffineSymbolExpr(num, context);
124889f4bf2Swren romano     dimsAndSymbols.emplace_back(name, affine);
125889f4bf2Swren romano     lvlsAndSymbols.emplace_back(name, affine);
126889f4bf2Swren romano     return var;
127889f4bf2Swren romano   }
128889f4bf2Swren romano   case VarKind::Dimension:
129889f4bf2Swren romano     dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
130889f4bf2Swren romano     return var;
131889f4bf2Swren romano   case VarKind::Level:
132889f4bf2Swren romano     lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
133889f4bf2Swren romano     return var;
134889f4bf2Swren romano   }
135889f4bf2Swren romano   llvm_unreachable("unknown VarKind");
1366b88c852SAart Bik }
1376b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `DimLvlMapParser` implementation for `DimLvlMap` per se.
//===----------------------------------------------------------------------===//
1416b88c852SAart Bik 
parseDimLvlMap()1426b88c852SAart Bik FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
143889f4bf2Swren romano   FAILURE_IF_FAILED(parseSymbolBindingList())
144889f4bf2Swren romano   FAILURE_IF_FAILED(parseLvlVarBindingList())
1456b88c852SAart Bik   FAILURE_IF_FAILED(parseDimSpecList())
1466b88c852SAart Bik   FAILURE_IF_FAILED(parser.parseArrow())
1476b88c852SAart Bik   FAILURE_IF_FAILED(parseLvlSpecList())
1486b88c852SAart Bik   InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
1496b88c852SAart Bik   if (failed(ifd))
1506b88c852SAart Bik     return ifd;
1516b88c852SAart Bik   return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
1526b88c852SAart Bik }
1536b88c852SAart Bik 
parseSymbolBindingList()154889f4bf2Swren romano ParseResult DimLvlMapParser::parseSymbolBindingList() {
155889f4bf2Swren romano   return parser.parseCommaSeparatedList(
156889f4bf2Swren romano       OpAsmParser::Delimiter::OptionalSquare,
157889f4bf2Swren romano       [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
158889f4bf2Swren romano       " in symbol binding list");
159889f4bf2Swren romano }
160889f4bf2Swren romano 
parseLvlVarBindingList()161889f4bf2Swren romano ParseResult DimLvlMapParser::parseLvlVarBindingList() {
162889f4bf2Swren romano   return parser.parseCommaSeparatedList(
163889f4bf2Swren romano       OpAsmParser::Delimiter::OptionalBraces,
164889f4bf2Swren romano       [this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
165889f4bf2Swren romano       " in level declaration list");
1666b88c852SAart Bik }
1676b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `DimLvlMapParser` implementation for `DimSpec`.
//===----------------------------------------------------------------------===//
1716b88c852SAart Bik 
parseDimSpecList()1726b88c852SAart Bik ParseResult DimLvlMapParser::parseDimSpecList() {
1736b88c852SAart Bik   return parser.parseCommaSeparatedList(
174b939c015SAart Bik       OpAsmParser::Delimiter::Paren,
175889f4bf2Swren romano       [this]() -> ParseResult { return parseDimSpec(); },
1766b88c852SAart Bik       " in dimension-specifier list");
1776b88c852SAart Bik }
1786b88c852SAart Bik 
parseDimSpec()1796b88c852SAart Bik ParseResult DimLvlMapParser::parseDimSpec() {
180889f4bf2Swren romano   // Parse the requisite dim-var binding.
181889f4bf2Swren romano   const auto varID = parseVarBinding(VarKind::Dimension);
182889f4bf2Swren romano   FAILURE_IF_FAILED(varID)
183889f4bf2Swren romano   const DimVar var = env.getVar(*varID).cast<DimVar>();
1846b88c852SAart Bik 
185b939c015SAart Bik   // Parse an optional dimension expression.
186b939c015SAart Bik   AffineExpr affine;
1876b88c852SAart Bik   if (succeeded(parser.parseOptionalEqual())) {
188b939c015SAart Bik     // Parse the dim affine expr, with only any lvl-vars in scope.
189889f4bf2Swren romano     FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
1906b88c852SAart Bik   }
191b939c015SAart Bik   DimExpr expr{affine};
1926b88c852SAart Bik 
193b939c015SAart Bik   // Parse an optional slice.
1946b88c852SAart Bik   SparseTensorDimSliceAttr slice;
1956b88c852SAart Bik   if (succeeded(parser.parseOptionalColon())) {
1966b88c852SAart Bik     const auto loc = parser.getCurrentLocation();
1976b88c852SAart Bik     Attribute attr;
1986b88c852SAart Bik     FAILURE_IF_FAILED(parser.parseAttribute(attr))
1996b88c852SAart Bik     slice = llvm::dyn_cast<SparseTensorDimSliceAttr>(attr);
2006b88c852SAart Bik     ERROR_IF(!slice, "expected SparseTensorDimSliceAttr")
2016b88c852SAart Bik   }
2026b88c852SAart Bik 
2036b88c852SAart Bik   dimSpecs.emplace_back(var, expr, slice);
2046b88c852SAart Bik   return success();
2056b88c852SAart Bik }
2066b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `DimLvlMapParser` implementation for `LvlSpec`.
//===----------------------------------------------------------------------===//
2106b88c852SAart Bik 
parseLvlSpecList()2116b88c852SAart Bik ParseResult DimLvlMapParser::parseLvlSpecList() {
212889f4bf2Swren romano   // This method currently only supports two syntaxes:
213889f4bf2Swren romano   //
214889f4bf2Swren romano   // (1) There are no forward-declarations, and no lvl-var bindings:
215b939c015SAart Bik   //        (d0, d1) -> (d0 : dense, d1 : compressed)
216889f4bf2Swren romano   // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
217889f4bf2Swren romano   // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
218889f4bf2Swren romano   // the level-rank is correct at the end of parsing.
219889f4bf2Swren romano   //
220889f4bf2Swren romano   // (2) There are forward-declarations, and every lvl-spec must have
221889f4bf2Swren romano   // a lvl-var binding:
222b939c015SAart Bik   //    {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
223889f4bf2Swren romano   // However, this introduces duplicate information since the order of
224889f4bf2Swren romano   // the lvl-vars in `parseLvlVarBindingList` must agree with their order
225889f4bf2Swren romano   // in the list of lvl-specs.  Therefore, `parseLvlSpec` will not call
226889f4bf2Swren romano   // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
227889f4bf2Swren romano   // and must also validate the consistency between the two lvl-var orders.
228889f4bf2Swren romano   const auto declaredLvlRank = env.getRanks().getLvlRank();
229889f4bf2Swren romano   const bool requireLvlVarBinding = declaredLvlRank != 0;
230889f4bf2Swren romano   // Have `ERROR_IF` point to the start of the list.
231889f4bf2Swren romano   const auto loc = parser.getCurrentLocation();
232889f4bf2Swren romano   const auto res = parser.parseCommaSeparatedList(
233b939c015SAart Bik       mlir::OpAsmParser::Delimiter::Paren,
234889f4bf2Swren romano       [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
2356b88c852SAart Bik       " in level-specifier list");
236889f4bf2Swren romano   FAILURE_IF_FAILED(res)
237889f4bf2Swren romano   const auto specLvlRank = lvlSpecs.size();
238889f4bf2Swren romano   ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
239889f4bf2Swren romano            "Level-rank mismatch between forward-declarations and specifiers. "
240889f4bf2Swren romano            "Declared " +
241889f4bf2Swren romano                Twine(declaredLvlRank) + " level-variables; but got " +
242889f4bf2Swren romano                Twine(specLvlRank) + " level-specifiers.")
243889f4bf2Swren romano   return success();
2446b88c852SAart Bik }
2456b88c852SAart Bik 
nth(Var::Num n)246889f4bf2Swren romano static inline Twine nth(Var::Num n) {
247889f4bf2Swren romano   switch (n) {
248889f4bf2Swren romano   case 1:
249889f4bf2Swren romano     return "1st";
250889f4bf2Swren romano   case 2:
251889f4bf2Swren romano     return "2nd";
252889f4bf2Swren romano   default:
253889f4bf2Swren romano     return Twine(n) + "th";
254889f4bf2Swren romano   }
255889f4bf2Swren romano }
256889f4bf2Swren romano 
257889f4bf2Swren romano FailureOr<LvlVar>
parseLvlVarBinding(bool requireLvlVarBinding)258889f4bf2Swren romano DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
259889f4bf2Swren romano   // Nothing to parse, just bind an unnamed variable.
260889f4bf2Swren romano   if (!requireLvlVarBinding)
261889f4bf2Swren romano     return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
262889f4bf2Swren romano 
263889f4bf2Swren romano   const auto loc = parser.getCurrentLocation();
264889f4bf2Swren romano   // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
265889f4bf2Swren romano   // since the thing we're parsing is supposed to be a variable *binding*
266889f4bf2Swren romano   // rather than a variable *use*.  However, the call to `VarEnv::bindVar`
267889f4bf2Swren romano   // (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
268889f4bf2Swren romano   // already occured in `parseLvlVarBindingList`, and therefore we must
269889f4bf2Swren romano   // use `parseVarUsage` here in order to operationally do the right thing.
270889f4bf2Swren romano   const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
271889f4bf2Swren romano   FAILURE_IF_FAILED(varID)
272889f4bf2Swren romano   const auto &info = std::as_const(env).access(*varID);
273889f4bf2Swren romano   const auto var = info.getVar().cast<LvlVar>();
274889f4bf2Swren romano   const auto forwardNum = var.getNum();
275889f4bf2Swren romano   const auto specNum = lvlSpecs.size();
276889f4bf2Swren romano   ERROR_IF(forwardNum != specNum,
277889f4bf2Swren romano            "Level-variable ordering mismatch. The variable '" + info.getName() +
278889f4bf2Swren romano                "' was forward-declared as the " + nth(forwardNum) +
279889f4bf2Swren romano                " level; but is bound by the " + nth(specNum) +
280889f4bf2Swren romano                " specification.")
281889f4bf2Swren romano   FAILURE_IF_FAILED(parser.parseEqual())
282889f4bf2Swren romano   return var;
283889f4bf2Swren romano }
284889f4bf2Swren romano 
parseLvlSpec(bool requireLvlVarBinding)285889f4bf2Swren romano ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
28634ed07e6SYinying Li   // Parse the optional lvl-var binding. `requireLvlVarBinding`
28734ed07e6SYinying Li   // specifies whether that "optional" is actually Must or MustNot.
288889f4bf2Swren romano   const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
289889f4bf2Swren romano   FAILURE_IF_FAILED(varRes)
290889f4bf2Swren romano   const LvlVar var = *varRes;
2916b88c852SAart Bik 
292b939c015SAart Bik   // Parse the lvl affine expr, with only the dim-vars in scope.
293b939c015SAart Bik   AffineExpr affine;
294b939c015SAart Bik   FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
295b939c015SAart Bik   LvlExpr expr{affine};
2966b88c852SAart Bik 
2976b88c852SAart Bik   FAILURE_IF_FAILED(parser.parseColon())
2986b88c852SAart Bik   const auto type = lvlTypeParser.parseLvlType(parser);
2996b88c852SAart Bik   FAILURE_IF_FAILED(type)
3006b88c852SAart Bik 
301*1944c4f7SAart Bik   lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
3026b88c852SAart Bik   return success();
3036b88c852SAart Bik }
3046b88c852SAart Bik 
3056b88c852SAart Bik //===----------------------------------------------------------------------===//
306