16b88c852SAart Bik //===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
26b88c852SAart Bik //
36b88c852SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46b88c852SAart Bik // See https://llvm.org/LICENSE.txt for license information.
56b88c852SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66b88c852SAart Bik //
76b88c852SAart Bik //===----------------------------------------------------------------------===//
86b88c852SAart Bik
96b88c852SAart Bik #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
106b88c852SAart Bik #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
116b88c852SAart Bik
126b88c852SAart Bik #include "Var.h"
136b88c852SAart Bik
146b88c852SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
157c2ef38cSVlad Serebrennikov #include "llvm/ADT/STLForwardCompat.h"
166b88c852SAart Bik
176b88c852SAart Bik namespace mlir {
186b88c852SAart Bik namespace sparse_tensor {
196b88c852SAart Bik namespace ir_detail {
206b88c852SAart Bik
216b88c852SAart Bik //===----------------------------------------------------------------------===//
226b88c852SAart Bik enum class ExprKind : bool { Dimension = false, Level = true };
236b88c852SAart Bik
getVarKindAllowedInExpr(ExprKind ek)246b88c852SAart Bik constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
256b88c852SAart Bik using VK = std::underlying_type_t<VarKind>;
267c2ef38cSVlad Serebrennikov return VarKind{2 * static_cast<VK>(!llvm::to_underlying(ek))};
276b88c852SAart Bik }
286b88c852SAart Bik static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
296b88c852SAart Bik getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
306b88c852SAart Bik
316b88c852SAart Bik //===----------------------------------------------------------------------===//
326b88c852SAart Bik class DimLvlExpr {
336b88c852SAart Bik private:
346b88c852SAart Bik ExprKind kind;
356b88c852SAart Bik AffineExpr expr;
366b88c852SAart Bik
376b88c852SAart Bik public:
DimLvlExpr(ExprKind ek,AffineExpr expr)386b88c852SAart Bik constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
396b88c852SAart Bik
406b88c852SAart Bik //
416b88c852SAart Bik // Boolean operators.
426b88c852SAart Bik //
436b88c852SAart Bik constexpr bool operator==(DimLvlExpr other) const {
446b88c852SAart Bik return kind == other.kind && expr == other.expr;
456b88c852SAart Bik }
466b88c852SAart Bik constexpr bool operator!=(DimLvlExpr other) const {
476b88c852SAart Bik return !(*this == other);
486b88c852SAart Bik }
496b88c852SAart Bik explicit operator bool() const { return static_cast<bool>(expr); }
506b88c852SAart Bik
516b88c852SAart Bik //
526b88c852SAart Bik // RTTI support (for the `DimLvlExpr` class itself).
536b88c852SAart Bik //
546b88c852SAart Bik template <typename U>
556b88c852SAart Bik constexpr bool isa() const;
566b88c852SAart Bik template <typename U>
576b88c852SAart Bik constexpr U cast() const;
586b88c852SAart Bik template <typename U>
596b88c852SAart Bik constexpr U dyn_cast() const;
606b88c852SAart Bik
616b88c852SAart Bik //
626b88c852SAart Bik // Simple getters.
636b88c852SAart Bik //
getExprKind()646b88c852SAart Bik constexpr ExprKind getExprKind() const { return kind; }
getAllowedVarKind()656b88c852SAart Bik constexpr VarKind getAllowedVarKind() const {
666b88c852SAart Bik return getVarKindAllowedInExpr(kind);
676b88c852SAart Bik }
getAffineExpr()686b88c852SAart Bik constexpr AffineExpr getAffineExpr() const { return expr; }
getAffineKind()696b88c852SAart Bik AffineExprKind getAffineKind() const {
706b88c852SAart Bik assert(expr);
716b88c852SAart Bik return expr.getKind();
726b88c852SAart Bik }
tryGetContext()73f9d50531Swren romano MLIRContext *tryGetContext() const {
74f9d50531Swren romano return expr ? expr.getContext() : nullptr;
75f9d50531Swren romano }
766b88c852SAart Bik
776b88c852SAart Bik //
786b88c852SAart Bik // Getters for handling `AffineExpr` subclasses.
796b88c852SAart Bik //
806b88c852SAart Bik SymVar castSymVar() const;
8178921a64Swren romano std::optional<SymVar> dyn_castSymVar() const;
826b88c852SAart Bik Var castDimLvlVar() const;
8378921a64Swren romano std::optional<Var> dyn_castDimLvlVar() const;
846b88c852SAart Bik std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
856b88c852SAart Bik
866b88c852SAart Bik /// Checks whether the variables bound/used by this spec are valid
876b88c852SAart Bik /// with respect to the given ranks.
882a682138Swren romano [[nodiscard]] bool isValid(Ranks const &ranks) const;
896b88c852SAart Bik
906b88c852SAart Bik protected:
916b88c852SAart Bik // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
926b88c852SAart Bik enum class BindingStrength : bool { Weak = false, Strong = true };
936b88c852SAart Bik };
946b88c852SAart Bik static_assert(IsZeroCostAbstraction<DimLvlExpr>);
956b88c852SAart Bik
966b88c852SAart Bik class DimExpr final : public DimLvlExpr {
976b88c852SAart Bik friend class DimLvlExpr;
DimExpr(DimLvlExpr expr)986b88c852SAart Bik constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
996b88c852SAart Bik
1006b88c852SAart Bik public:
1016b88c852SAart Bik static constexpr ExprKind Kind = ExprKind::Dimension;
classof(DimLvlExpr const * expr)1026b88c852SAart Bik static constexpr bool classof(DimLvlExpr const *expr) {
1036b88c852SAart Bik return expr->getExprKind() == Kind;
1046b88c852SAart Bik }
DimExpr(AffineExpr expr)1056b88c852SAart Bik constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
10678921a64Swren romano
castLvlVar()10778921a64Swren romano LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
dyn_castLvlVar()10878921a64Swren romano std::optional<LvlVar> dyn_castLvlVar() const {
10978921a64Swren romano const auto var = dyn_castDimLvlVar();
11078921a64Swren romano return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
11178921a64Swren romano }
1126b88c852SAart Bik };
1136b88c852SAart Bik static_assert(IsZeroCostAbstraction<DimExpr>);
1146b88c852SAart Bik
1156b88c852SAart Bik class LvlExpr final : public DimLvlExpr {
1166b88c852SAart Bik friend class DimLvlExpr;
LvlExpr(DimLvlExpr expr)1176b88c852SAart Bik constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
1186b88c852SAart Bik
1196b88c852SAart Bik public:
1206b88c852SAart Bik static constexpr ExprKind Kind = ExprKind::Level;
classof(DimLvlExpr const * expr)1216b88c852SAart Bik static constexpr bool classof(DimLvlExpr const *expr) {
1226b88c852SAart Bik return expr->getExprKind() == Kind;
1236b88c852SAart Bik }
LvlExpr(AffineExpr expr)1246b88c852SAart Bik constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
12578921a64Swren romano
castDimVar()12678921a64Swren romano DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
dyn_castDimVar()12778921a64Swren romano std::optional<DimVar> dyn_castDimVar() const {
12878921a64Swren romano const auto var = dyn_castDimLvlVar();
12978921a64Swren romano return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
13078921a64Swren romano }
1316b88c852SAart Bik };
1326b88c852SAart Bik static_assert(IsZeroCostAbstraction<LvlExpr>);
1336b88c852SAart Bik
1346b88c852SAart Bik template <typename U>
isa()1356b88c852SAart Bik constexpr bool DimLvlExpr::isa() const {
1366b88c852SAart Bik if constexpr (std::is_same_v<U, DimExpr>)
1376b88c852SAart Bik return getExprKind() == ExprKind::Dimension;
1386b88c852SAart Bik if constexpr (std::is_same_v<U, LvlExpr>)
1396b88c852SAart Bik return getExprKind() == ExprKind::Level;
1406b88c852SAart Bik }
1416b88c852SAart Bik
1426b88c852SAart Bik template <typename U>
cast()1436b88c852SAart Bik constexpr U DimLvlExpr::cast() const {
1446b88c852SAart Bik assert(isa<U>());
1456b88c852SAart Bik return U(*this);
1466b88c852SAart Bik }
1476b88c852SAart Bik
1486b88c852SAart Bik template <typename U>
dyn_cast()1496b88c852SAart Bik constexpr U DimLvlExpr::dyn_cast() const {
1506b88c852SAart Bik return isa<U>() ? U(*this) : U();
1516b88c852SAart Bik }
1526b88c852SAart Bik
1536b88c852SAart Bik //===----------------------------------------------------------------------===//
1546b88c852SAart Bik /// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
1556b88c852SAart Bik class DimSpec final {
1566b88c852SAart Bik /// The dimension-variable bound by this specification.
1576b88c852SAart Bik DimVar var;
1586b88c852SAart Bik /// The dimension-expression. The `DimSpec` ctor treats this field
1596b88c852SAart Bik /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
1606b88c852SAart Bik /// the expression via function-inversion inference.
1616b88c852SAart Bik DimExpr expr;
1626b88c852SAart Bik /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
1636b88c852SAart Bik /// not (though if `expr` is null it will elide printing that); whereas
1646b88c852SAart Bik /// the `DimLvlMap` ctor will reset it as appropriate.
1656b88c852SAart Bik bool elideExpr = false;
1666b88c852SAart Bik /// The dimension-slice; optional, default is null.
1676b88c852SAart Bik SparseTensorDimSliceAttr slice;
1686b88c852SAart Bik
1696b88c852SAart Bik public:
1706b88c852SAart Bik DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);
1716b88c852SAart Bik
tryGetContext()172f9d50531Swren romano MLIRContext *tryGetContext() const { return expr.tryGetContext(); }
173f9d50531Swren romano
getBoundVar()1746b88c852SAart Bik constexpr DimVar getBoundVar() const { return var; }
hasExpr()1756b88c852SAart Bik bool hasExpr() const { return static_cast<bool>(expr); }
getExpr()1766b88c852SAart Bik constexpr DimExpr getExpr() const { return expr; }
setExpr(DimExpr newExpr)1776b88c852SAart Bik void setExpr(DimExpr newExpr) {
1786b88c852SAart Bik assert(!hasExpr());
1796b88c852SAart Bik expr = newExpr;
1806b88c852SAart Bik }
canElideExpr()1816b88c852SAart Bik constexpr bool canElideExpr() const { return elideExpr; }
setElideExpr(bool b)1826b88c852SAart Bik void setElideExpr(bool b) { elideExpr = b; }
getSlice()1836b88c852SAart Bik constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
1846b88c852SAart Bik
1853b00f448Swren romano /// Checks whether the variables bound/used by this spec are valid with
1863b00f448Swren romano /// respect to the given ranks. Note that null `DimExpr` is considered
1873b00f448Swren romano /// to be vacuously valid, and therefore calling `setExpr` invalidates
1883b00f448Swren romano /// the result of this predicate.
1892a682138Swren romano [[nodiscard]] bool isValid(Ranks const &ranks) const;
1906b88c852SAart Bik };
19134ed07e6SYinying Li
1926b88c852SAart Bik static_assert(IsZeroCostAbstraction<DimSpec>);
1936b88c852SAart Bik
1946b88c852SAart Bik //===----------------------------------------------------------------------===//
1956b88c852SAart Bik /// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
1966b88c852SAart Bik class LvlSpec final {
1976b88c852SAart Bik /// The level-variable bound by this specification.
1986b88c852SAart Bik LvlVar var;
1996b88c852SAart Bik /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not;
2006b88c852SAart Bik /// whereas the `DimLvlMap` ctor will reset this as appropriate.
2016b88c852SAart Bik bool elideVar = false;
2026b88c852SAart Bik /// The level-expression.
2036b88c852SAart Bik LvlExpr expr;
2046b88c852SAart Bik /// The level-type (== level-format + lvl-properties).
205*1944c4f7SAart Bik LevelType type;
2066b88c852SAart Bik
2076b88c852SAart Bik public:
208*1944c4f7SAart Bik LvlSpec(LvlVar var, LvlExpr expr, LevelType type);
2096b88c852SAart Bik
getContext()210f9d50531Swren romano MLIRContext *getContext() const {
211f9d50531Swren romano MLIRContext *ctx = expr.tryGetContext();
212f9d50531Swren romano assert(ctx);
213f9d50531Swren romano return ctx;
214f9d50531Swren romano }
215f9d50531Swren romano
getBoundVar()2166b88c852SAart Bik constexpr LvlVar getBoundVar() const { return var; }
canElideVar()2176b88c852SAart Bik constexpr bool canElideVar() const { return elideVar; }
setElideVar(bool b)2186b88c852SAart Bik void setElideVar(bool b) { elideVar = b; }
getExpr()2196b88c852SAart Bik constexpr LvlExpr getExpr() const { return expr; }
getType()220*1944c4f7SAart Bik constexpr LevelType getType() const { return type; }
2216b88c852SAart Bik
2226b88c852SAart Bik /// Checks whether the variables bound/used by this spec are valid
2236b88c852SAart Bik /// with respect to the given ranks.
2242a682138Swren romano [[nodiscard]] bool isValid(Ranks const &ranks) const;
2256b88c852SAart Bik };
22634ed07e6SYinying Li
2276b88c852SAart Bik static_assert(IsZeroCostAbstraction<LvlSpec>);
2286b88c852SAart Bik
2296b88c852SAart Bik //===----------------------------------------------------------------------===//
2306b88c852SAart Bik class DimLvlMap final {
2316b88c852SAart Bik public:
2326b88c852SAart Bik DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
2336b88c852SAart Bik ArrayRef<LvlSpec> lvlSpecs);
2346b88c852SAart Bik
getSymRank()2356b88c852SAart Bik unsigned getSymRank() const { return symRank; }
getDimRank()2366b88c852SAart Bik unsigned getDimRank() const { return dimSpecs.size(); }
getLvlRank()2376b88c852SAart Bik unsigned getLvlRank() const { return lvlSpecs.size(); }
getRank(VarKind vk)2386b88c852SAart Bik unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
getRanks()2396b88c852SAart Bik Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
2406b88c852SAart Bik
getDims()241fdbe9312Swren romano ArrayRef<DimSpec> getDims() const { return dimSpecs; }
getDim(Dimension dim)242fdbe9312Swren romano const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
getDimSlice(Dimension dim)243fdbe9312Swren romano SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
244fdbe9312Swren romano return getDim(dim).getSlice();
245fdbe9312Swren romano }
246fdbe9312Swren romano
getLvls()247fdbe9312Swren romano ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
getLvl(Level lvl)248fdbe9312Swren romano const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
getLvlType(Level lvl)249*1944c4f7SAart Bik LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
250fdbe9312Swren romano
251fdbe9312Swren romano AffineMap getDimToLvlMap(MLIRContext *context) const;
252fdbe9312Swren romano AffineMap getLvlToDimMap(MLIRContext *context) const;
2536b88c852SAart Bik
254fdbe9312Swren romano private:
255fdbe9312Swren romano /// Checks for integrity of variable-binding structure.
256fdbe9312Swren romano /// This is already called by the ctor.
257fdbe9312Swren romano [[nodiscard]] bool isWF() const;
258fdbe9312Swren romano
259fdbe9312Swren romano /// Helper function to call `DimSpec::setExpr` while asserting that
260fdbe9312Swren romano /// the invariant established by `DimLvlMap:isWF` is maintained.
261fdbe9312Swren romano /// This is used by the ctor.
setDimExpr(Dimension dim,DimExpr expr)262fdbe9312Swren romano void setDimExpr(Dimension dim, DimExpr expr) {
263fdbe9312Swren romano assert(expr && getRanks().isValid(expr));
264fdbe9312Swren romano dimSpecs[dim].setExpr(expr);
265fdbe9312Swren romano }
266fdbe9312Swren romano
267fdbe9312Swren romano // All these fields are const-after-ctor.
268fdbe9312Swren romano unsigned symRank;
269fdbe9312Swren romano SmallVector<DimSpec> dimSpecs;
270fdbe9312Swren romano SmallVector<LvlSpec> lvlSpecs;
271fdbe9312Swren romano bool mustPrintLvlVars;
2726b88c852SAart Bik };
2736b88c852SAart Bik
2746b88c852SAart Bik //===----------------------------------------------------------------------===//
2756b88c852SAart Bik
2766b88c852SAart Bik } // namespace ir_detail
2776b88c852SAart Bik } // namespace sparse_tensor
2786b88c852SAart Bik } // namespace mlir
2796b88c852SAart Bik
2806b88c852SAart Bik #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
281