1 //===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
11
12 #include "Var.h"
13
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "llvm/ADT/STLForwardCompat.h"
16
17 namespace mlir {
18 namespace sparse_tensor {
19 namespace ir_detail {
20
21 //===----------------------------------------------------------------------===//
22 enum class ExprKind : bool { Dimension = false, Level = true };
23
getVarKindAllowedInExpr(ExprKind ek)24 constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
25 using VK = std::underlying_type_t<VarKind>;
26 return VarKind{2 * static_cast<VK>(!llvm::to_underlying(ek))};
27 }
28 static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
29 getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
30
31 //===----------------------------------------------------------------------===//
32 class DimLvlExpr {
33 private:
34 ExprKind kind;
35 AffineExpr expr;
36
37 public:
DimLvlExpr(ExprKind ek,AffineExpr expr)38 constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
39
40 //
41 // Boolean operators.
42 //
43 constexpr bool operator==(DimLvlExpr other) const {
44 return kind == other.kind && expr == other.expr;
45 }
46 constexpr bool operator!=(DimLvlExpr other) const {
47 return !(*this == other);
48 }
49 explicit operator bool() const { return static_cast<bool>(expr); }
50
51 //
52 // RTTI support (for the `DimLvlExpr` class itself).
53 //
54 template <typename U>
55 constexpr bool isa() const;
56 template <typename U>
57 constexpr U cast() const;
58 template <typename U>
59 constexpr U dyn_cast() const;
60
61 //
62 // Simple getters.
63 //
getExprKind()64 constexpr ExprKind getExprKind() const { return kind; }
getAllowedVarKind()65 constexpr VarKind getAllowedVarKind() const {
66 return getVarKindAllowedInExpr(kind);
67 }
getAffineExpr()68 constexpr AffineExpr getAffineExpr() const { return expr; }
getAffineKind()69 AffineExprKind getAffineKind() const {
70 assert(expr);
71 return expr.getKind();
72 }
tryGetContext()73 MLIRContext *tryGetContext() const {
74 return expr ? expr.getContext() : nullptr;
75 }
76
77 //
78 // Getters for handling `AffineExpr` subclasses.
79 //
80 SymVar castSymVar() const;
81 std::optional<SymVar> dyn_castSymVar() const;
82 Var castDimLvlVar() const;
83 std::optional<Var> dyn_castDimLvlVar() const;
84 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
85
86 /// Checks whether the variables bound/used by this spec are valid
87 /// with respect to the given ranks.
88 [[nodiscard]] bool isValid(Ranks const &ranks) const;
89
90 protected:
91 // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
92 enum class BindingStrength : bool { Weak = false, Strong = true };
93 };
94 static_assert(IsZeroCostAbstraction<DimLvlExpr>);
95
96 class DimExpr final : public DimLvlExpr {
97 friend class DimLvlExpr;
DimExpr(DimLvlExpr expr)98 constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
99
100 public:
101 static constexpr ExprKind Kind = ExprKind::Dimension;
classof(DimLvlExpr const * expr)102 static constexpr bool classof(DimLvlExpr const *expr) {
103 return expr->getExprKind() == Kind;
104 }
DimExpr(AffineExpr expr)105 constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
106
castLvlVar()107 LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
dyn_castLvlVar()108 std::optional<LvlVar> dyn_castLvlVar() const {
109 const auto var = dyn_castDimLvlVar();
110 return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
111 }
112 };
113 static_assert(IsZeroCostAbstraction<DimExpr>);
114
115 class LvlExpr final : public DimLvlExpr {
116 friend class DimLvlExpr;
LvlExpr(DimLvlExpr expr)117 constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
118
119 public:
120 static constexpr ExprKind Kind = ExprKind::Level;
classof(DimLvlExpr const * expr)121 static constexpr bool classof(DimLvlExpr const *expr) {
122 return expr->getExprKind() == Kind;
123 }
LvlExpr(AffineExpr expr)124 constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
125
castDimVar()126 DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
dyn_castDimVar()127 std::optional<DimVar> dyn_castDimVar() const {
128 const auto var = dyn_castDimLvlVar();
129 return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
130 }
131 };
132 static_assert(IsZeroCostAbstraction<LvlExpr>);
133
134 template <typename U>
isa()135 constexpr bool DimLvlExpr::isa() const {
136 if constexpr (std::is_same_v<U, DimExpr>)
137 return getExprKind() == ExprKind::Dimension;
138 if constexpr (std::is_same_v<U, LvlExpr>)
139 return getExprKind() == ExprKind::Level;
140 }
141
142 template <typename U>
cast()143 constexpr U DimLvlExpr::cast() const {
144 assert(isa<U>());
145 return U(*this);
146 }
147
148 template <typename U>
dyn_cast()149 constexpr U DimLvlExpr::dyn_cast() const {
150 return isa<U>() ? U(*this) : U();
151 }
152
153 //===----------------------------------------------------------------------===//
154 /// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
155 class DimSpec final {
156 /// The dimension-variable bound by this specification.
157 DimVar var;
158 /// The dimension-expression. The `DimSpec` ctor treats this field
159 /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
160 /// the expression via function-inversion inference.
161 DimExpr expr;
162 /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
163 /// not (though if `expr` is null it will elide printing that); whereas
164 /// the `DimLvlMap` ctor will reset it as appropriate.
165 bool elideExpr = false;
166 /// The dimension-slice; optional, default is null.
167 SparseTensorDimSliceAttr slice;
168
169 public:
170 DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);
171
tryGetContext()172 MLIRContext *tryGetContext() const { return expr.tryGetContext(); }
173
getBoundVar()174 constexpr DimVar getBoundVar() const { return var; }
hasExpr()175 bool hasExpr() const { return static_cast<bool>(expr); }
getExpr()176 constexpr DimExpr getExpr() const { return expr; }
setExpr(DimExpr newExpr)177 void setExpr(DimExpr newExpr) {
178 assert(!hasExpr());
179 expr = newExpr;
180 }
canElideExpr()181 constexpr bool canElideExpr() const { return elideExpr; }
setElideExpr(bool b)182 void setElideExpr(bool b) { elideExpr = b; }
getSlice()183 constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
184
185 /// Checks whether the variables bound/used by this spec are valid with
186 /// respect to the given ranks. Note that null `DimExpr` is considered
187 /// to be vacuously valid, and therefore calling `setExpr` invalidates
188 /// the result of this predicate.
189 [[nodiscard]] bool isValid(Ranks const &ranks) const;
190 };
191
192 static_assert(IsZeroCostAbstraction<DimSpec>);
193
194 //===----------------------------------------------------------------------===//
195 /// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
196 class LvlSpec final {
197 /// The level-variable bound by this specification.
198 LvlVar var;
199 /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not;
200 /// whereas the `DimLvlMap` ctor will reset this as appropriate.
201 bool elideVar = false;
202 /// The level-expression.
203 LvlExpr expr;
204 /// The level-type (== level-format + lvl-properties).
205 LevelType type;
206
207 public:
208 LvlSpec(LvlVar var, LvlExpr expr, LevelType type);
209
getContext()210 MLIRContext *getContext() const {
211 MLIRContext *ctx = expr.tryGetContext();
212 assert(ctx);
213 return ctx;
214 }
215
getBoundVar()216 constexpr LvlVar getBoundVar() const { return var; }
canElideVar()217 constexpr bool canElideVar() const { return elideVar; }
setElideVar(bool b)218 void setElideVar(bool b) { elideVar = b; }
getExpr()219 constexpr LvlExpr getExpr() const { return expr; }
getType()220 constexpr LevelType getType() const { return type; }
221
222 /// Checks whether the variables bound/used by this spec are valid
223 /// with respect to the given ranks.
224 [[nodiscard]] bool isValid(Ranks const &ranks) const;
225 };
226
227 static_assert(IsZeroCostAbstraction<LvlSpec>);
228
229 //===----------------------------------------------------------------------===//
230 class DimLvlMap final {
231 public:
232 DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
233 ArrayRef<LvlSpec> lvlSpecs);
234
getSymRank()235 unsigned getSymRank() const { return symRank; }
getDimRank()236 unsigned getDimRank() const { return dimSpecs.size(); }
getLvlRank()237 unsigned getLvlRank() const { return lvlSpecs.size(); }
getRank(VarKind vk)238 unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
getRanks()239 Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
240
getDims()241 ArrayRef<DimSpec> getDims() const { return dimSpecs; }
getDim(Dimension dim)242 const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
getDimSlice(Dimension dim)243 SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
244 return getDim(dim).getSlice();
245 }
246
getLvls()247 ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
getLvl(Level lvl)248 const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
getLvlType(Level lvl)249 LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
250
251 AffineMap getDimToLvlMap(MLIRContext *context) const;
252 AffineMap getLvlToDimMap(MLIRContext *context) const;
253
254 private:
255 /// Checks for integrity of variable-binding structure.
256 /// This is already called by the ctor.
257 [[nodiscard]] bool isWF() const;
258
259 /// Helper function to call `DimSpec::setExpr` while asserting that
260 /// the invariant established by `DimLvlMap:isWF` is maintained.
261 /// This is used by the ctor.
setDimExpr(Dimension dim,DimExpr expr)262 void setDimExpr(Dimension dim, DimExpr expr) {
263 assert(expr && getRanks().isValid(expr));
264 dimSpecs[dim].setExpr(expr);
265 }
266
267 // All these fields are const-after-ctor.
268 unsigned symRank;
269 SmallVector<DimSpec> dimSpecs;
270 SmallVector<LvlSpec> lvlSpecs;
271 bool mustPrintLvlVars;
272 };
273
274 //===----------------------------------------------------------------------===//
275
276 } // namespace ir_detail
277 } // namespace sparse_tensor
278 } // namespace mlir
279
280 #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
281