xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h (revision 34ed07e6a1a1d5218d1d323577b348568377303d)
//===- DimLvlMapParser.h - `DimLvlMap` parser -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
86b88c852SAart Bik 
96b88c852SAart Bik #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
106b88c852SAart Bik #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
116b88c852SAart Bik 
126b88c852SAart Bik #include "DimLvlMap.h"
136b88c852SAart Bik #include "LvlTypeParser.h"
146b88c852SAart Bik 
156b88c852SAart Bik namespace mlir {
166b88c852SAart Bik namespace sparse_tensor {
176b88c852SAart Bik namespace ir_detail {
186b88c852SAart Bik 
19b939c015SAart Bik ///
20b939c015SAart Bik /// Parses the Sparse Tensor Encoding Attribute (STEA).
21b939c015SAart Bik ///
22b939c015SAart Bik /// General syntax is as follows,
23b939c015SAart Bik ///
24b939c015SAart Bik ///   [s0, ...]     // optional forward decl sym-vars
25b939c015SAart Bik ///   {l0, ...}     // optional forward decl lvl-vars
26b939c015SAart Bik ///   (
27b939c015SAart Bik ///     d0 = ...,   // dim-var = dim-exp
28b939c015SAart Bik ///     ...
29b939c015SAart Bik ///   ) -> (
30b939c015SAart Bik ///     l0 = ...,   // lvl-var = lvl-exp
31b939c015SAart Bik ///     ...
32b939c015SAart Bik ///   )
33b939c015SAart Bik ///
34b939c015SAart Bik /// with simplifications when variables are implicit.
35b939c015SAart Bik ///
366b88c852SAart Bik class DimLvlMapParser final {
376b88c852SAart Bik public:
DimLvlMapParser(AsmParser & parser)386b88c852SAart Bik   explicit DimLvlMapParser(AsmParser &parser) : parser(parser) {}
396b88c852SAart Bik 
406b88c852SAart Bik   // Parses the input for a sparse tensor dimension-level map
416b88c852SAart Bik   // and returns the map on success.
426b88c852SAart Bik   FailureOr<DimLvlMap> parseDimLvlMap();
436b88c852SAart Bik 
446b88c852SAart Bik private:
45*34ed07e6SYinying Li   /// Client code should prefer using `parseVarUsage`
46*34ed07e6SYinying Li   /// and `parseVarBinding` rather than calling this method directly.
476b88c852SAart Bik   OptionalParseResult parseVar(VarKind vk, bool isOptional,
48ad7a6b67Swren romano                                Policy creationPolicy, VarInfo::ID &id,
496b88c852SAart Bik                                bool &didCreate);
506b88c852SAart Bik 
51*34ed07e6SYinying Li   /// Parses a variable occurence which is a *use* of that variable.
52*34ed07e6SYinying Li   /// When a valid variable name is currently unused, if
53*34ed07e6SYinying Li   /// `requireKnown=true`, an error is raised; if `requireKnown=false`,
54889f4bf2Swren romano   /// a new unbound variable will be created.
55889f4bf2Swren romano   FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
56889f4bf2Swren romano 
57*34ed07e6SYinying Li   /// Parses a variable occurence which is a *binding* of that variable.
58889f4bf2Swren romano   /// The `requireKnown` parameter is for handling the binding of
59889f4bf2Swren romano   /// forward-declared variables.
60889f4bf2Swren romano   FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
61889f4bf2Swren romano 
62*34ed07e6SYinying Li   /// Parses an optional variable binding. When the next token is
63889f4bf2Swren romano   /// not a valid variable name, this will bind a new unnamed variable.
64889f4bf2Swren romano   /// The returned `bool` indicates whether a variable name was parsed.
65889f4bf2Swren romano   FailureOr<std::pair<Var, bool>>
66889f4bf2Swren romano   parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
67889f4bf2Swren romano 
68889f4bf2Swren romano   /// Binds the given variable: both updating the `VarEnv` itself, and
69*34ed07e6SYinying Li   /// the `{dims,lvls}AndSymbols` lists (which will be passed
70889f4bf2Swren romano   /// to `AsmParser::parseAffineExpr`). This method is already called by the
71889f4bf2Swren romano   /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
72889f4bf2Swren romano   /// not need to be called elsewhere.
73889f4bf2Swren romano   Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
74889f4bf2Swren romano 
75889f4bf2Swren romano   ParseResult parseSymbolBindingList();
76889f4bf2Swren romano   ParseResult parseLvlVarBindingList();
776b88c852SAart Bik   ParseResult parseDimSpec();
786b88c852SAart Bik   ParseResult parseDimSpecList();
79889f4bf2Swren romano   FailureOr<LvlVar> parseLvlVarBinding(bool requireLvlVarBinding);
80889f4bf2Swren romano   ParseResult parseLvlSpec(bool requireLvlVarBinding);
816b88c852SAart Bik   ParseResult parseLvlSpecList();
826b88c852SAart Bik 
836b88c852SAart Bik   AsmParser &parser;
846b88c852SAart Bik   LvlTypeParser lvlTypeParser;
856b88c852SAart Bik   VarEnv env;
86889f4bf2Swren romano   // The parser maintains the `{dims,lvls}AndSymbols` lists to avoid
87889f4bf2Swren romano   // the O(n^2) cost of repeatedly constructing them inside of the
88889f4bf2Swren romano   // `parse{Dim,Lvl}Spec` methods.
89889f4bf2Swren romano   SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
90889f4bf2Swren romano   SmallVector<std::pair<StringRef, AffineExpr>, 4> lvlsAndSymbols;
916b88c852SAart Bik   SmallVector<DimSpec> dimSpecs;
926b88c852SAart Bik   SmallVector<LvlSpec> lvlSpecs;
936b88c852SAart Bik };
946b88c852SAart Bik 
956b88c852SAart Bik } // namespace ir_detail
966b88c852SAart Bik } // namespace sparse_tensor
976b88c852SAart Bik } // namespace mlir
986b88c852SAart Bik 
996b88c852SAart Bik #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
100