xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h (revision 1944c4f76b47c0b86c91845987baca24fd4775f8)
1 //===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
11 
12 #include "Var.h"
13 
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "llvm/ADT/STLForwardCompat.h"
16 
17 namespace mlir {
18 namespace sparse_tensor {
19 namespace ir_detail {
20 
21 //===----------------------------------------------------------------------===//
22 enum class ExprKind : bool { Dimension = false, Level = true };
23 
getVarKindAllowedInExpr(ExprKind ek)24 constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
25   using VK = std::underlying_type_t<VarKind>;
26   return VarKind{2 * static_cast<VK>(!llvm::to_underlying(ek))};
27 }
28 static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
29               getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
30 
31 //===----------------------------------------------------------------------===//
32 class DimLvlExpr {
33 private:
34   ExprKind kind;
35   AffineExpr expr;
36 
37 public:
DimLvlExpr(ExprKind ek,AffineExpr expr)38   constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
39 
40   //
41   // Boolean operators.
42   //
43   constexpr bool operator==(DimLvlExpr other) const {
44     return kind == other.kind && expr == other.expr;
45   }
46   constexpr bool operator!=(DimLvlExpr other) const {
47     return !(*this == other);
48   }
49   explicit operator bool() const { return static_cast<bool>(expr); }
50 
51   //
52   // RTTI support (for the `DimLvlExpr` class itself).
53   //
54   template <typename U>
55   constexpr bool isa() const;
56   template <typename U>
57   constexpr U cast() const;
58   template <typename U>
59   constexpr U dyn_cast() const;
60 
61   //
62   // Simple getters.
63   //
getExprKind()64   constexpr ExprKind getExprKind() const { return kind; }
getAllowedVarKind()65   constexpr VarKind getAllowedVarKind() const {
66     return getVarKindAllowedInExpr(kind);
67   }
getAffineExpr()68   constexpr AffineExpr getAffineExpr() const { return expr; }
getAffineKind()69   AffineExprKind getAffineKind() const {
70     assert(expr);
71     return expr.getKind();
72   }
tryGetContext()73   MLIRContext *tryGetContext() const {
74     return expr ? expr.getContext() : nullptr;
75   }
76 
77   //
78   // Getters for handling `AffineExpr` subclasses.
79   //
80   SymVar castSymVar() const;
81   std::optional<SymVar> dyn_castSymVar() const;
82   Var castDimLvlVar() const;
83   std::optional<Var> dyn_castDimLvlVar() const;
84   std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
85 
86   /// Checks whether the variables bound/used by this spec are valid
87   /// with respect to the given ranks.
88   [[nodiscard]] bool isValid(Ranks const &ranks) const;
89 
90 protected:
91   // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
92   enum class BindingStrength : bool { Weak = false, Strong = true };
93 };
94 static_assert(IsZeroCostAbstraction<DimLvlExpr>);
95 
96 class DimExpr final : public DimLvlExpr {
97   friend class DimLvlExpr;
DimExpr(DimLvlExpr expr)98   constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
99 
100 public:
101   static constexpr ExprKind Kind = ExprKind::Dimension;
classof(DimLvlExpr const * expr)102   static constexpr bool classof(DimLvlExpr const *expr) {
103     return expr->getExprKind() == Kind;
104   }
DimExpr(AffineExpr expr)105   constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
106 
castLvlVar()107   LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
dyn_castLvlVar()108   std::optional<LvlVar> dyn_castLvlVar() const {
109     const auto var = dyn_castDimLvlVar();
110     return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
111   }
112 };
113 static_assert(IsZeroCostAbstraction<DimExpr>);
114 
115 class LvlExpr final : public DimLvlExpr {
116   friend class DimLvlExpr;
LvlExpr(DimLvlExpr expr)117   constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
118 
119 public:
120   static constexpr ExprKind Kind = ExprKind::Level;
classof(DimLvlExpr const * expr)121   static constexpr bool classof(DimLvlExpr const *expr) {
122     return expr->getExprKind() == Kind;
123   }
LvlExpr(AffineExpr expr)124   constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
125 
castDimVar()126   DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
dyn_castDimVar()127   std::optional<DimVar> dyn_castDimVar() const {
128     const auto var = dyn_castDimLvlVar();
129     return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
130   }
131 };
132 static_assert(IsZeroCostAbstraction<LvlExpr>);
133 
134 template <typename U>
isa()135 constexpr bool DimLvlExpr::isa() const {
136   if constexpr (std::is_same_v<U, DimExpr>)
137     return getExprKind() == ExprKind::Dimension;
138   if constexpr (std::is_same_v<U, LvlExpr>)
139     return getExprKind() == ExprKind::Level;
140 }
141 
142 template <typename U>
cast()143 constexpr U DimLvlExpr::cast() const {
144   assert(isa<U>());
145   return U(*this);
146 }
147 
148 template <typename U>
dyn_cast()149 constexpr U DimLvlExpr::dyn_cast() const {
150   return isa<U>() ? U(*this) : U();
151 }
152 
153 //===----------------------------------------------------------------------===//
154 /// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
155 class DimSpec final {
156   /// The dimension-variable bound by this specification.
157   DimVar var;
158   /// The dimension-expression.  The `DimSpec` ctor treats this field
159   /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
160   /// the expression via function-inversion inference.
161   DimExpr expr;
162   /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
163   /// not (though if `expr` is null it will elide printing that); whereas
164   /// the `DimLvlMap` ctor will reset it as appropriate.
165   bool elideExpr = false;
166   /// The dimension-slice; optional, default is null.
167   SparseTensorDimSliceAttr slice;
168 
169 public:
170   DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);
171 
tryGetContext()172   MLIRContext *tryGetContext() const { return expr.tryGetContext(); }
173 
getBoundVar()174   constexpr DimVar getBoundVar() const { return var; }
hasExpr()175   bool hasExpr() const { return static_cast<bool>(expr); }
getExpr()176   constexpr DimExpr getExpr() const { return expr; }
setExpr(DimExpr newExpr)177   void setExpr(DimExpr newExpr) {
178     assert(!hasExpr());
179     expr = newExpr;
180   }
canElideExpr()181   constexpr bool canElideExpr() const { return elideExpr; }
setElideExpr(bool b)182   void setElideExpr(bool b) { elideExpr = b; }
getSlice()183   constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
184 
185   /// Checks whether the variables bound/used by this spec are valid with
186   /// respect to the given ranks.  Note that null `DimExpr` is considered
187   /// to be vacuously valid, and therefore calling `setExpr` invalidates
188   /// the result of this predicate.
189   [[nodiscard]] bool isValid(Ranks const &ranks) const;
190 };
191 
192 static_assert(IsZeroCostAbstraction<DimSpec>);
193 
194 //===----------------------------------------------------------------------===//
195 /// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
196 class LvlSpec final {
197   /// The level-variable bound by this specification.
198   LvlVar var;
199   /// Can the `var` be elided when printing?  The `LvlSpec` ctor assumes not;
200   /// whereas the `DimLvlMap` ctor will reset this as appropriate.
201   bool elideVar = false;
202   /// The level-expression.
203   LvlExpr expr;
204   /// The level-type (== level-format + lvl-properties).
205   LevelType type;
206 
207 public:
208   LvlSpec(LvlVar var, LvlExpr expr, LevelType type);
209 
getContext()210   MLIRContext *getContext() const {
211     MLIRContext *ctx = expr.tryGetContext();
212     assert(ctx);
213     return ctx;
214   }
215 
getBoundVar()216   constexpr LvlVar getBoundVar() const { return var; }
canElideVar()217   constexpr bool canElideVar() const { return elideVar; }
setElideVar(bool b)218   void setElideVar(bool b) { elideVar = b; }
getExpr()219   constexpr LvlExpr getExpr() const { return expr; }
getType()220   constexpr LevelType getType() const { return type; }
221 
222   /// Checks whether the variables bound/used by this spec are valid
223   /// with respect to the given ranks.
224   [[nodiscard]] bool isValid(Ranks const &ranks) const;
225 };
226 
227 static_assert(IsZeroCostAbstraction<LvlSpec>);
228 
229 //===----------------------------------------------------------------------===//
230 class DimLvlMap final {
231 public:
232   DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
233             ArrayRef<LvlSpec> lvlSpecs);
234 
getSymRank()235   unsigned getSymRank() const { return symRank; }
getDimRank()236   unsigned getDimRank() const { return dimSpecs.size(); }
getLvlRank()237   unsigned getLvlRank() const { return lvlSpecs.size(); }
getRank(VarKind vk)238   unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
getRanks()239   Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
240 
getDims()241   ArrayRef<DimSpec> getDims() const { return dimSpecs; }
getDim(Dimension dim)242   const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
getDimSlice(Dimension dim)243   SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
244     return getDim(dim).getSlice();
245   }
246 
getLvls()247   ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
getLvl(Level lvl)248   const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
getLvlType(Level lvl)249   LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
250 
251   AffineMap getDimToLvlMap(MLIRContext *context) const;
252   AffineMap getLvlToDimMap(MLIRContext *context) const;
253 
254 private:
255   /// Checks for integrity of variable-binding structure.
256   /// This is already called by the ctor.
257   [[nodiscard]] bool isWF() const;
258 
259   /// Helper function to call `DimSpec::setExpr` while asserting that
260   /// the invariant established by `DimLvlMap:isWF` is maintained.
261   /// This is used by the ctor.
setDimExpr(Dimension dim,DimExpr expr)262   void setDimExpr(Dimension dim, DimExpr expr) {
263     assert(expr && getRanks().isValid(expr));
264     dimSpecs[dim].setExpr(expr);
265   }
266 
267   // All these fields are const-after-ctor.
268   unsigned symRank;
269   SmallVector<DimSpec> dimSpecs;
270   SmallVector<LvlSpec> lvlSpecs;
271   bool mustPrintLvlVars;
272 };
273 
274 //===----------------------------------------------------------------------===//
275 
276 } // namespace ir_detail
277 } // namespace sparse_tensor
278 } // namespace mlir
279 
280 #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
281