xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h (revision 34ed07e6a1a1d5218d1d323577b348568377303d)
//===- DimLvlMapParser.h - `DimLvlMap` parser -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
11 
12 #include "DimLvlMap.h"
13 #include "LvlTypeParser.h"
14 
15 namespace mlir {
16 namespace sparse_tensor {
17 namespace ir_detail {
18 
19 ///
20 /// Parses the Sparse Tensor Encoding Attribute (STEA).
21 ///
22 /// General syntax is as follows,
23 ///
24 ///   [s0, ...]     // optional forward decl sym-vars
25 ///   {l0, ...}     // optional forward decl lvl-vars
26 ///   (
27 ///     d0 = ...,   // dim-var = dim-exp
28 ///     ...
29 ///   ) -> (
30 ///     l0 = ...,   // lvl-var = lvl-exp
31 ///     ...
32 ///   )
33 ///
34 /// with simplifications when variables are implicit.
35 ///
36 class DimLvlMapParser final {
37 public:
DimLvlMapParser(AsmParser & parser)38   explicit DimLvlMapParser(AsmParser &parser) : parser(parser) {}
39 
40   // Parses the input for a sparse tensor dimension-level map
41   // and returns the map on success.
42   FailureOr<DimLvlMap> parseDimLvlMap();
43 
44 private:
45   /// Client code should prefer using `parseVarUsage`
46   /// and `parseVarBinding` rather than calling this method directly.
47   OptionalParseResult parseVar(VarKind vk, bool isOptional,
48                                Policy creationPolicy, VarInfo::ID &id,
49                                bool &didCreate);
50 
51   /// Parses a variable occurence which is a *use* of that variable.
52   /// When a valid variable name is currently unused, if
53   /// `requireKnown=true`, an error is raised; if `requireKnown=false`,
54   /// a new unbound variable will be created.
55   FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
56 
57   /// Parses a variable occurence which is a *binding* of that variable.
58   /// The `requireKnown` parameter is for handling the binding of
59   /// forward-declared variables.
60   FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
61 
62   /// Parses an optional variable binding. When the next token is
63   /// not a valid variable name, this will bind a new unnamed variable.
64   /// The returned `bool` indicates whether a variable name was parsed.
65   FailureOr<std::pair<Var, bool>>
66   parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
67 
68   /// Binds the given variable: both updating the `VarEnv` itself, and
69   /// the `{dims,lvls}AndSymbols` lists (which will be passed
70   /// to `AsmParser::parseAffineExpr`). This method is already called by the
71   /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
72   /// not need to be called elsewhere.
73   Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
74 
75   ParseResult parseSymbolBindingList();
76   ParseResult parseLvlVarBindingList();
77   ParseResult parseDimSpec();
78   ParseResult parseDimSpecList();
79   FailureOr<LvlVar> parseLvlVarBinding(bool requireLvlVarBinding);
80   ParseResult parseLvlSpec(bool requireLvlVarBinding);
81   ParseResult parseLvlSpecList();
82 
83   AsmParser &parser;
84   LvlTypeParser lvlTypeParser;
85   VarEnv env;
86   // The parser maintains the `{dims,lvls}AndSymbols` lists to avoid
87   // the O(n^2) cost of repeatedly constructing them inside of the
88   // `parse{Dim,Lvl}Spec` methods.
89   SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
90   SmallVector<std::pair<StringRef, AffineExpr>, 4> lvlsAndSymbols;
91   SmallVector<DimSpec> dimSpecs;
92   SmallVector<LvlSpec> lvlSpecs;
93 };
94 
95 } // namespace ir_detail
96 } // namespace sparse_tensor
97 } // namespace mlir
98 
99 #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAPPARSER_H
100