xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h (revision 704c22473641e26d95435c55aa482fbf5abbbc2c)
16b88c852SAart Bik //===- Var.h ----------------------------------------------------*- C++ -*-===//
26b88c852SAart Bik //
36b88c852SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46b88c852SAart Bik // See https://llvm.org/LICENSE.txt for license information.
56b88c852SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66b88c852SAart Bik //
76b88c852SAart Bik //===----------------------------------------------------------------------===//
86b88c852SAart Bik 
96b88c852SAart Bik #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
106b88c852SAart Bik #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
116b88c852SAart Bik 
126b88c852SAart Bik #include "TemplateExtras.h"
136b88c852SAart Bik 
146b88c852SAart Bik #include "mlir/IR/OpImplementation.h"
156b88c852SAart Bik #include "llvm/ADT/EnumeratedArray.h"
167c2ef38cSVlad Serebrennikov #include "llvm/ADT/STLForwardCompat.h"
176b88c852SAart Bik #include "llvm/ADT/SmallBitVector.h"
186b88c852SAart Bik #include "llvm/ADT/StringMap.h"
196b88c852SAart Bik 
206b88c852SAart Bik namespace mlir {
216b88c852SAart Bik namespace sparse_tensor {
226b88c852SAart Bik namespace ir_detail {
236b88c852SAart Bik 
246b88c852SAart Bik //===----------------------------------------------------------------------===//
/// Enumerates the three kinds of variable that a `Var` may denote.
///
/// NOTE: The concrete integer values assigned below are an implementation
/// detail and must not be relied upon as part of the API (they do need to
/// stay within [0,2], which `isWF` checks).  Throughout this file the
/// canonical ordering is `{Symbol,Dimension,Level}`, which deliberately
/// differs from the numerical order of the underlying values.
enum class VarKind {
  Symbol = 1,
  Dimension = 0,
  Level = 2,
};
336b88c852SAart Bik 
isWF(VarKind vk)342a682138Swren romano [[nodiscard]] constexpr bool isWF(VarKind vk) {
357c2ef38cSVlad Serebrennikov   const auto vk_ = llvm::to_underlying(vk);
366b88c852SAart Bik   return 0 <= vk_ && vk_ <= 2;
376b88c852SAart Bik }
386b88c852SAart Bik 
396b88c852SAart Bik /// Gets the ASCII character used as the prefix when printing `Var`.
toChar(VarKind vk)406b88c852SAart Bik constexpr char toChar(VarKind vk) {
416b88c852SAart Bik   // If `isWF(vk)` then this computation's intermediate results are always
426b88c852SAart Bik   // in the range [-44..126] (where that lower bound is under worst-case
436b88c852SAart Bik   // rearranging of the expression); and `int_fast8_t` is the fastest type
446b88c852SAart Bik   // which can support that range without over-/underflow.
457c2ef38cSVlad Serebrennikov   const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(vk));
466b88c852SAart Bik   return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
476b88c852SAart Bik }
486b88c852SAart Bik static_assert(toChar(VarKind::Symbol) == 's' &&
496b88c852SAart Bik               toChar(VarKind::Dimension) == 'd' &&
506b88c852SAart Bik               toChar(VarKind::Level) == 'l');
516b88c852SAart Bik 
526b88c852SAart Bik //===----------------------------------------------------------------------===//
536b88c852SAart Bik /// The type of arrays indexed by `VarKind`.
546b88c852SAart Bik template <typename T>
556b88c852SAart Bik using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
566b88c852SAart Bik 
576b88c852SAart Bik //===----------------------------------------------------------------------===//
586b88c852SAart Bik /// A concrete variable, to be used in our variant of `AffineExpr`.
5934ed07e6SYinying Li /// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
6034ed07e6SYinying Li /// support for subclasses with a fixed `VarKind`.
616b88c852SAart Bik class Var {
626b88c852SAart Bik public:
6368785c1cSwren romano   /// Typedef for the type of variable numbers.
646b88c852SAart Bik   using Num = unsigned;
656b88c852SAart Bik 
666b88c852SAart Bik private:
6768785c1cSwren romano   /// Typedef for the underlying storage of `Var::Impl`.
6868785c1cSwren romano   using Storage = unsigned;
696b88c852SAart Bik 
7068785c1cSwren romano   /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
7168785c1cSwren romano   /// Two low-order bits are reserved for storing the `VarKind`,
7268785c1cSwren romano   /// and one high-order bit is reserved for future use (e.g., to support
7368785c1cSwren romano   /// `DenseMapInfo<Var>` while maintaining the usual numeric values for
7468785c1cSwren romano   /// "empty" and "tombstone").
756b88c852SAart Bik   static constexpr Num kMaxNum =
7668785c1cSwren romano       static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);
776b88c852SAart Bik 
786b88c852SAart Bik public:
7968785c1cSwren romano   /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
8068785c1cSwren romano   //
816b88c852SAart Bik   // This must be public for `VarInfo` to use it (whereas we don't want
826b88c852SAart Bik   // to expose the `impl` field via friendship).
isWF_Num(Num n)832a682138Swren romano   [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
846b88c852SAart Bik 
8568785c1cSwren romano protected:
8668785c1cSwren romano   /// The underlying implementation of `Var`.  Note that this must be kept
8768785c1cSwren romano   /// distinct from `Var` itself, since we want to ensure that the RTTI
8868785c1cSwren romano   /// methods will select the `U(Var::Impl)` ctor rather than selecting
8968785c1cSwren romano   /// the `U(Var::Num)` ctor.
9068785c1cSwren romano   class Impl final {
9168785c1cSwren romano     Storage data;
9268785c1cSwren romano 
9368785c1cSwren romano   public:
Impl(VarKind vk,Num n)9468785c1cSwren romano     constexpr Impl(VarKind vk, Num n)
9568785c1cSwren romano         : data((static_cast<Storage>(n) << 2) |
967c2ef38cSVlad Serebrennikov                static_cast<Storage>(llvm::to_underlying(vk))) {
976b88c852SAart Bik       assert(isWF(vk) && "unknown VarKind");
986b88c852SAart Bik       assert(isWF_Num(n) && "Var::Num is too large");
996b88c852SAart Bik     }
10068785c1cSwren romano     constexpr bool operator==(Impl other) const { return data == other.data; }
10168785c1cSwren romano     constexpr bool operator!=(Impl other) const { return !(*this == other); }
getKind()10268785c1cSwren romano     constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
getNum()10368785c1cSwren romano     constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
10468785c1cSwren romano   };
10568785c1cSwren romano   static_assert(IsZeroCostAbstraction<Impl>);
10668785c1cSwren romano 
10768785c1cSwren romano private:
10868785c1cSwren romano   Impl impl;
10968785c1cSwren romano 
11068785c1cSwren romano protected:
11168785c1cSwren romano   /// Protected ctor for the RTTI methods to use.
Var(Impl impl)11268785c1cSwren romano   constexpr explicit Var(Impl impl) : impl(impl) {}
11368785c1cSwren romano 
11468785c1cSwren romano public:
Var(VarKind vk,Num n)11568785c1cSwren romano   constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
Var(AffineSymbolExpr sym)1166b88c852SAart Bik   Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
Var(VarKind vk,AffineDimExpr var)1173b00f448Swren romano   Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
1183b00f448Swren romano     assert(vk != VarKind::Symbol);
1193b00f448Swren romano   }
1206b88c852SAart Bik 
1216b88c852SAart Bik   constexpr bool operator==(Var other) const { return impl == other.impl; }
1226b88c852SAart Bik   constexpr bool operator!=(Var other) const { return !(*this == other); }
1236b88c852SAart Bik 
getKind()12468785c1cSwren romano   constexpr VarKind getKind() const { return impl.getKind(); }
getNum()12568785c1cSwren romano   constexpr Num getNum() const { return impl.getNum(); }
1266b88c852SAart Bik 
1276b88c852SAart Bik   template <typename U>
1286b88c852SAart Bik   constexpr bool isa() const;
1296b88c852SAart Bik   template <typename U>
1306b88c852SAart Bik   constexpr U cast() const;
1316b88c852SAart Bik   template <typename U>
13268785c1cSwren romano   constexpr std::optional<U> dyn_cast() const;
1336b88c852SAart Bik 
134f5b974b7Swren romano   std::string str() const;
1356b88c852SAart Bik   void print(llvm::raw_ostream &os) const;
1366b88c852SAart Bik   void print(AsmPrinter &printer) const;
1376b88c852SAart Bik   void dump() const;
1386b88c852SAart Bik };
1396b88c852SAart Bik static_assert(IsZeroCostAbstraction<Var>);
1406b88c852SAart Bik 
1416b88c852SAart Bik class SymVar final : public Var {
14268785c1cSwren romano   using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
1436b88c852SAart Bik public:
1446b88c852SAart Bik   static constexpr VarKind Kind = VarKind::Symbol;
classof(Var const * var)1456b88c852SAart Bik   static constexpr bool classof(Var const *var) {
1466b88c852SAart Bik     return var->getKind() == Kind;
1476b88c852SAart Bik   }
SymVar(Num sym)1486b88c852SAart Bik   constexpr SymVar(Num sym) : Var(Kind, sym) {}
SymVar(AffineSymbolExpr symExpr)1496b88c852SAart Bik   SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
1506b88c852SAart Bik };
1516b88c852SAart Bik static_assert(IsZeroCostAbstraction<SymVar>);
1526b88c852SAart Bik 
1536b88c852SAart Bik class DimVar final : public Var {
15468785c1cSwren romano   using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
1556b88c852SAart Bik public:
1566b88c852SAart Bik   static constexpr VarKind Kind = VarKind::Dimension;
classof(Var const * var)1576b88c852SAart Bik   static constexpr bool classof(Var const *var) {
1586b88c852SAart Bik     return var->getKind() == Kind;
1596b88c852SAart Bik   }
DimVar(Num dim)1606b88c852SAart Bik   constexpr DimVar(Num dim) : Var(Kind, dim) {}
DimVar(AffineDimExpr dimExpr)1616b88c852SAart Bik   DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
1626b88c852SAart Bik };
1636b88c852SAart Bik static_assert(IsZeroCostAbstraction<DimVar>);
1646b88c852SAart Bik 
1656b88c852SAart Bik class LvlVar final : public Var {
16668785c1cSwren romano   using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
1676b88c852SAart Bik public:
1686b88c852SAart Bik   static constexpr VarKind Kind = VarKind::Level;
classof(Var const * var)1696b88c852SAart Bik   static constexpr bool classof(Var const *var) {
1706b88c852SAart Bik     return var->getKind() == Kind;
1716b88c852SAart Bik   }
LvlVar(Num lvl)1726b88c852SAart Bik   constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
LvlVar(AffineDimExpr lvlExpr)1736b88c852SAart Bik   LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
1746b88c852SAart Bik };
1756b88c852SAart Bik static_assert(IsZeroCostAbstraction<LvlVar>);
1766b88c852SAart Bik 
1776b88c852SAart Bik template <typename U>
isa()1786b88c852SAart Bik constexpr bool Var::isa() const {
1796b88c852SAart Bik   if constexpr (std::is_same_v<U, SymVar>)
1806b88c852SAart Bik     return getKind() == VarKind::Symbol;
1816b88c852SAart Bik   if constexpr (std::is_same_v<U, DimVar>)
1826b88c852SAart Bik     return getKind() == VarKind::Dimension;
1836b88c852SAart Bik   if constexpr (std::is_same_v<U, LvlVar>)
1846b88c852SAart Bik     return getKind() == VarKind::Level;
1856b88c852SAart Bik }
1866b88c852SAart Bik 
1876b88c852SAart Bik template <typename U>
cast()1886b88c852SAart Bik constexpr U Var::cast() const {
1896b88c852SAart Bik   assert(isa<U>());
19068785c1cSwren romano   // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
19168785c1cSwren romano   return U(impl);
1926b88c852SAart Bik }
1936b88c852SAart Bik 
1946b88c852SAart Bik template <typename U>
dyn_cast()19568785c1cSwren romano constexpr std::optional<U> Var::dyn_cast() const {
19668785c1cSwren romano   // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
19768785c1cSwren romano   return isa<U>() ? std::make_optional(U(impl)) : std::nullopt;
1986b88c852SAart Bik }
1996b88c852SAart Bik 
2006b88c852SAart Bik //===----------------------------------------------------------------------===//
2016b88c852SAart Bik // Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
2026b88c852SAart Bik class DimLvlExpr;
2036b88c852SAart Bik 
2046b88c852SAart Bik //===----------------------------------------------------------------------===//
2056b88c852SAart Bik class Ranks final {
2066b88c852SAart Bik   // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
2076b88c852SAart Bik   unsigned impl[3];
2086b88c852SAart Bik 
to_index(VarKind vk)2096b88c852SAart Bik   static constexpr unsigned to_index(VarKind vk) {
2106b88c852SAart Bik     assert(isWF(vk) && "unknown VarKind");
2117c2ef38cSVlad Serebrennikov     return static_cast<unsigned>(llvm::to_underlying(vk));
2126b88c852SAart Bik   }
2136b88c852SAart Bik 
2146b88c852SAart Bik public:
Ranks(unsigned symRank,unsigned dimRank,unsigned lvlRank)2156b88c852SAart Bik   constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
2166b88c852SAart Bik       : impl() {
2176b88c852SAart Bik     impl[to_index(VarKind::Symbol)] = symRank;
2186b88c852SAart Bik     impl[to_index(VarKind::Dimension)] = dimRank;
2196b88c852SAart Bik     impl[to_index(VarKind::Level)] = lvlRank;
2206b88c852SAart Bik   }
Ranks(VarKindArray<unsigned> const & ranks)2216b88c852SAart Bik   Ranks(VarKindArray<unsigned> const &ranks)
2226b88c852SAart Bik       : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
2236b88c852SAart Bik               ranks[VarKind::Level]) {}
2246b88c852SAart Bik 
2255df63ad8Swren romano   bool operator==(Ranks const &other) const;
2265df63ad8Swren romano   bool operator!=(Ranks const &other) const { return !(*this == other); }
2275df63ad8Swren romano 
getRank(VarKind vk)2286b88c852SAart Bik   constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
getSymRank()2296b88c852SAart Bik   constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
getDimRank()2306b88c852SAart Bik   constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
getLvlRank()2316b88c852SAart Bik   constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }
2326b88c852SAart Bik 
isValid(Var var)2332a682138Swren romano   [[nodiscard]] constexpr bool isValid(Var var) const {
2346b88c852SAart Bik     return var.getNum() < getRank(var.getKind());
2356b88c852SAart Bik   }
2362a682138Swren romano   [[nodiscard]] bool isValid(DimLvlExpr expr) const;
2376b88c852SAart Bik };
2386b88c852SAart Bik static_assert(IsZeroCostAbstraction<Ranks>);
2396b88c852SAart Bik 
2406b88c852SAart Bik //===----------------------------------------------------------------------===//
241dcadb68aSwren romano /// Efficient representation of a set of `Var`.
2426b88c852SAart Bik class VarSet final {
2436b88c852SAart Bik   VarKindArray<llvm::SmallBitVector> impl;
2446b88c852SAart Bik 
2456b88c852SAart Bik public:
2466b88c852SAart Bik   explicit VarSet(Ranks const &ranks);
2476b88c852SAart Bik 
getRank(VarKind vk)2485df63ad8Swren romano   unsigned getRank(VarKind vk) const { return impl[vk].size(); }
getSymRank()2495df63ad8Swren romano   unsigned getSymRank() const { return getRank(VarKind::Symbol); }
getDimRank()2505df63ad8Swren romano   unsigned getDimRank() const { return getRank(VarKind::Dimension); }
getLvlRank()2515df63ad8Swren romano   unsigned getLvlRank() const { return getRank(VarKind::Level); }
getRanks()2525df63ad8Swren romano   Ranks getRanks() const {
2535df63ad8Swren romano     return Ranks(getSymRank(), getDimRank(), getLvlRank());
2545df63ad8Swren romano   }
255*704c2247SYinying Li   /// For the `contains` method: if variables occurring in
25634ed07e6SYinying Li   /// the method parameter are OOB for the `VarSet`, then these methods will
25734ed07e6SYinying Li   /// always return false.
2586b88c852SAart Bik   bool contains(Var var) const;
2596b88c852SAart Bik 
26034ed07e6SYinying Li   /// For the `add` methods: OOB parameters cause undefined behavior.
26134ed07e6SYinying Li   /// Currently the `add` methods will raise an assertion error.
2626b88c852SAart Bik   void add(Var var);
263dcadb68aSwren romano   void add(VarSet const &vars);
2646b88c852SAart Bik   void add(DimLvlExpr expr);
2656b88c852SAart Bik };
2666b88c852SAart Bik 
2676b88c852SAart Bik //===----------------------------------------------------------------------===//
2686b88c852SAart Bik /// A record of metadata for/about a variable, used by `VarEnv`.
2696b88c852SAart Bik /// The principal goal of this record is to enable `VarEnv` to be used for
2706b88c852SAart Bik /// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to
2716b88c852SAart Bik /// remain unknown, since each record is instead identified by `VarInfo::ID`.
2726b88c852SAart Bik /// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever
2736b88c852SAart Bik /// order it likes, irrespective of the binding order (`Var::Num`) of the
2746b88c852SAart Bik /// associated variable.
2756b88c852SAart Bik class VarInfo final {
2766b88c852SAart Bik public:
2776b88c852SAart Bik   /// Newtype for unique identifiers of `VarInfo` records, to ensure
2786b88c852SAart Bik   /// they aren't confused with `Var::Num`.
2796b88c852SAart Bik   enum class ID : unsigned {};
2806b88c852SAart Bik 
2816b88c852SAart Bik private:
2826b88c852SAart Bik   StringRef name;              // The bare-id used in the MLIR source.
2836b88c852SAart Bik   llvm::SMLoc loc;             // The location of the first occurence.
2846b88c852SAart Bik   ID id;                       // The unique `VarInfo`-identifier.
2856b88c852SAart Bik   std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
2866b88c852SAart Bik   VarKind kind;                // The kind of variable.
2876b88c852SAart Bik 
2886b88c852SAart Bik public:
2896b88c852SAart Bik   constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
2906b88c852SAart Bik                     std::optional<Var::Num> n = {})
name(name)2916b88c852SAart Bik       : name(name), loc(loc), id(id), num(n), kind(vk) {
2926b88c852SAart Bik     assert(!name.empty() && "null StringRef");
29347cf7a4bSwren romano     assert(loc.isValid() && "null SMLoc");
2946b88c852SAart Bik     assert(isWF(vk) && "unknown VarKind");
2956b88c852SAart Bik     assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
2966b88c852SAart Bik   }
2976b88c852SAart Bik 
getName()2986b88c852SAart Bik   constexpr StringRef getName() const { return name; }
getLoc()2996b88c852SAart Bik   constexpr llvm::SMLoc getLoc() const { return loc; }
getLocation(AsmParser & parser)3006b88c852SAart Bik   Location getLocation(AsmParser &parser) const {
3016b88c852SAart Bik     return parser.getEncodedSourceLoc(loc);
3026b88c852SAart Bik   }
getID()3036b88c852SAart Bik   constexpr ID getID() const { return id; }
getKind()3046b88c852SAart Bik   constexpr VarKind getKind() const { return kind; }
getNum()3056b88c852SAart Bik   constexpr std::optional<Var::Num> getNum() const { return num; }
hasNum()3066b88c852SAart Bik   constexpr bool hasNum() const { return num.has_value(); }
3076b88c852SAart Bik   void setNum(Var::Num n);
getVar()3086b88c852SAart Bik   constexpr Var getVar() const {
3096b88c852SAart Bik     assert(hasNum());
3106b88c852SAart Bik     return Var(kind, *num);
3116b88c852SAart Bik   }
3126b88c852SAart Bik };
3136b88c852SAart Bik 
3146b88c852SAart Bik //===----------------------------------------------------------------------===//
/// Creation policy for `VarEnv::lookupOrCreate`: a new variable must not,
/// may, or must be created.
enum class Policy {
  MustNot,
  May,
  Must,
};
3166b88c852SAart Bik 
317ad7a6b67Swren romano //===----------------------------------------------------------------------===//
3186b88c852SAart Bik class VarEnv final {
3196b88c852SAart Bik   /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
3206b88c852SAart Bik   VarKindArray<Var::Num> nextNum;
3216b88c852SAart Bik   /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
3226b88c852SAart Bik   SmallVector<VarInfo> vars;
3236b88c852SAart Bik   /// Map from variable names to their `VarInfo::ID`.
3246b88c852SAart Bik   llvm::StringMap<VarInfo::ID> ids;
3256b88c852SAart Bik 
nextID()3266b88c852SAart Bik   VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
3276b88c852SAart Bik 
3286b88c852SAart Bik public:
VarEnv()3296b88c852SAart Bik   VarEnv() : nextNum(0) {}
3306b88c852SAart Bik 
3316b88c852SAart Bik   /// Gets the underlying storage for the `VarInfo` identified by
3326b88c852SAart Bik   /// the `VarInfo::ID`.
3336b88c852SAart Bik   ///
3346b88c852SAart Bik   /// NOTE: The returned reference can become dangling if the `VarEnv`
3356b88c852SAart Bik   /// object is mutated during the lifetime of the pointer.  Therefore,
3366b88c852SAart Bik   /// client code should not store the reference nor otherwise allow it
3376b88c852SAart Bik   /// to live too long.
access(VarInfo::ID id)3386b88c852SAart Bik   VarInfo const &access(VarInfo::ID id) const {
3396b88c852SAart Bik     // `SmallVector::operator[]` already asserts the index is in-bounds.
3407c2ef38cSVlad Serebrennikov     return vars[llvm::to_underlying(id)];
3416b88c852SAart Bik   }
access(std::optional<VarInfo::ID> oid)3426b88c852SAart Bik   VarInfo const *access(std::optional<VarInfo::ID> oid) const {
3436b88c852SAart Bik     return oid ? &access(*oid) : nullptr;
3446b88c852SAart Bik   }
3456b88c852SAart Bik 
3466b88c852SAart Bik private:
access(VarInfo::ID id)3476b88c852SAart Bik   VarInfo &access(VarInfo::ID id) {
3486b88c852SAart Bik     return const_cast<VarInfo &>(std::as_const(*this).access(id));
3496b88c852SAart Bik   }
access(std::optional<VarInfo::ID> oid)3506b88c852SAart Bik   VarInfo *access(std::optional<VarInfo::ID> oid) {
3516b88c852SAart Bik     return const_cast<VarInfo *>(std::as_const(*this).access(oid));
3526b88c852SAart Bik   }
3536b88c852SAart Bik 
3546b88c852SAart Bik public:
35534ed07e6SYinying Li   /// Looks up the variable with the given name.
3566b88c852SAart Bik   std::optional<VarInfo::ID> lookup(StringRef name) const;
3576b88c852SAart Bik 
35834ed07e6SYinying Li   /// Creates a new currently-unbound variable.  When a variable
3596b88c852SAart Bik   /// of that name already exists: if `verifyUsage` is true, then will assert
3606b88c852SAart Bik   /// that the variable has the same kind and a consistent location; otherwise,
3616b88c852SAart Bik   /// when `verifyUsage` is false, this is a noop.  Returns the identifier
36234ed07e6SYinying Li   /// for the variable with the given name, and a bool indicating whether
3636b88c852SAart Bik   /// a new variable was created.
3648466eb7dSYinying Li   std::optional<std::pair<VarInfo::ID, bool>>
3658466eb7dSYinying Li   create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
3666b88c852SAart Bik 
36734ed07e6SYinying Li   /// Looks up or creates a variable according to the given
368ad7a6b67Swren romano   /// `Policy`.  Returns nullopt in one of two circumstances:
3696b88c852SAart Bik   /// (1) the policy says we `Must` create, yet the variable already exists;
3706b88c852SAart Bik   /// (2) the policy says we `MustNot` create, yet no such variable exists.
3716b88c852SAart Bik   /// Otherwise, if the variable already exists then it is validated against
3726b88c852SAart Bik   /// the given kind and location to ensure consistency.
3736b88c852SAart Bik   std::optional<std::pair<VarInfo::ID, bool>>
374ad7a6b67Swren romano   lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
3756b88c852SAart Bik                  VarKind vk);
3766b88c852SAart Bik 
3776b88c852SAart Bik   /// Binds the given variable to the next free `Var::Num` for its `VarKind`.
3786b88c852SAart Bik   Var bindVar(VarInfo::ID id);
3796b88c852SAart Bik 
3806b88c852SAart Bik   /// Creates a new variable of the given kind and immediately binds it.
3816b88c852SAart Bik   /// This should only be used whenever the variable is known to be unused
3826b88c852SAart Bik   /// and therefore does not have a name.
3836b88c852SAart Bik   Var bindUnusedVar(VarKind vk);
3846b88c852SAart Bik 
3856b88c852SAart Bik   InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
3866b88c852SAart Bik 
387889f4bf2Swren romano   /// Returns the current ranks of bound variables.  This method should
388889f4bf2Swren romano   /// only be used after the environment is "finished", since binding new
389889f4bf2Swren romano   /// variables will (semantically) invalidate any previously returned `Ranks`.
getRanks()3906b88c852SAart Bik   Ranks getRanks() const { return Ranks(nextNum); }
391b939c015SAart Bik 
392889f4bf2Swren romano   /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
393889f4bf2Swren romano   /// failure if the variable is not bound.
getVar(VarInfo::ID id)394889f4bf2Swren romano   Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
3956b88c852SAart Bik };
3966b88c852SAart Bik 
3976b88c852SAart Bik //===----------------------------------------------------------------------===//
3986b88c852SAart Bik 
3996b88c852SAart Bik } // namespace ir_detail
4006b88c852SAart Bik } // namespace sparse_tensor
4016b88c852SAart Bik } // namespace mlir
4026b88c852SAart Bik 
4036b88c852SAart Bik #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
404