//===- DimLvlMap.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
86b88c852SAart Bik
96b88c852SAart Bik #include "DimLvlMap.h"
106b88c852SAart Bik
116b88c852SAart Bik using namespace mlir;
126b88c852SAart Bik using namespace mlir::sparse_tensor;
136b88c852SAart Bik using namespace mlir::sparse_tensor::ir_detail;
146b88c852SAart Bik
//===----------------------------------------------------------------------===//
// `DimLvlExpr` implementation.
//===----------------------------------------------------------------------===//
186b88c852SAart Bik
castSymVar() const196b88c852SAart Bik SymVar DimLvlExpr::castSymVar() const {
201609f1c2Slong.chen return SymVar(llvm::cast<AffineSymbolExpr>(expr));
216b88c852SAart Bik }
226b88c852SAart Bik
dyn_castSymVar() const2378921a64Swren romano std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
241609f1c2Slong.chen if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
2578921a64Swren romano return SymVar(s);
2678921a64Swren romano return std::nullopt;
2778921a64Swren romano }
2878921a64Swren romano
castDimLvlVar() const296b88c852SAart Bik Var DimLvlExpr::castDimLvlVar() const {
301609f1c2Slong.chen return Var(getAllowedVarKind(), llvm::cast<AffineDimExpr>(expr));
316b88c852SAart Bik }
326b88c852SAart Bik
dyn_castDimLvlVar() const3378921a64Swren romano std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
341609f1c2Slong.chen if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
3578921a64Swren romano return Var(getAllowedVarKind(), x);
3678921a64Swren romano return std::nullopt;
3778921a64Swren romano }
3878921a64Swren romano
396b88c852SAart Bik std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
unpackBinop() const406b88c852SAart Bik DimLvlExpr::unpackBinop() const {
416b88c852SAart Bik const auto ak = getAffineKind();
421609f1c2Slong.chen const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
436b88c852SAart Bik const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
446b88c852SAart Bik const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
456b88c852SAart Bik return {lhs, ak, rhs};
466b88c852SAart Bik }
476b88c852SAart Bik
//===----------------------------------------------------------------------===//
// `DimSpec` implementation.
//===----------------------------------------------------------------------===//
516b88c852SAart Bik
DimSpec(DimVar var,DimExpr expr,SparseTensorDimSliceAttr slice)526b88c852SAart Bik DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
536b88c852SAart Bik : var(var), expr(expr), slice(slice) {}
546b88c852SAart Bik
isValid(Ranks const & ranks) const556b88c852SAart Bik bool DimSpec::isValid(Ranks const &ranks) const {
563b00f448Swren romano // Nothing in `slice` needs additional validation.
573b00f448Swren romano // We explicitly consider null-expr to be vacuously valid.
583b00f448Swren romano return ranks.isValid(var) && (!expr || ranks.isValid(expr));
596b88c852SAart Bik }
606b88c852SAart Bik
//===----------------------------------------------------------------------===//
// `LvlSpec` implementation.
//===----------------------------------------------------------------------===//
646b88c852SAart Bik
LvlSpec(LvlVar var,LvlExpr expr,LevelType type)65*1944c4f7SAart Bik LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
666b88c852SAart Bik : var(var), expr(expr), type(type) {
676b88c852SAart Bik assert(expr);
681dd387e1SAart Bik assert(isValidLT(type) && !isUndefLT(type));
696b88c852SAart Bik }
706b88c852SAart Bik
isValid(Ranks const & ranks) const716b88c852SAart Bik bool LvlSpec::isValid(Ranks const &ranks) const {
723b00f448Swren romano // Nothing in `type` needs additional validation.
736b88c852SAart Bik return ranks.isValid(var) && ranks.isValid(expr);
746b88c852SAart Bik }
756b88c852SAart Bik
//===----------------------------------------------------------------------===//
// `DimLvlMap` implementation.
//===----------------------------------------------------------------------===//
796b88c852SAart Bik
DimLvlMap(unsigned symRank,ArrayRef<DimSpec> dimSpecs,ArrayRef<LvlSpec> lvlSpecs)806b88c852SAart Bik DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
816b88c852SAart Bik ArrayRef<LvlSpec> lvlSpecs)
8204959123Swren romano : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
8304959123Swren romano mustPrintLvlVars(false) {
846b88c852SAart Bik // First, check integrity of the variable-binding structure.
85497050c9Swren romano // NOTE: This establishes the invariant that calls to `VarSet::add`
86497050c9Swren romano // below cannot cause OOB errors.
876b88c852SAart Bik assert(isWF());
886b88c852SAart Bik
896b88c852SAart Bik VarSet usedVars(getRanks());
90497050c9Swren romano for (const auto &dimSpec : dimSpecs)
91497050c9Swren romano if (!dimSpec.canElideExpr())
92497050c9Swren romano usedVars.add(dimSpec.getExpr());
9304959123Swren romano for (auto &lvlSpec : this->lvlSpecs) {
9404959123Swren romano // Is this LvlVar used in any overt expression?
9504959123Swren romano const bool isUsed = usedVars.contains(lvlSpec.getBoundVar());
9604959123Swren romano // This LvlVar can be elided iff it isn't overtly used.
9704959123Swren romano lvlSpec.setElideVar(!isUsed);
9804959123Swren romano // If any LvlVar cannot be elided, then must forward-declare all LvlVars.
9904959123Swren romano mustPrintLvlVars = mustPrintLvlVars || isUsed;
10004959123Swren romano }
1016b88c852SAart Bik }
1026b88c852SAart Bik
isWF() const1036b88c852SAart Bik bool DimLvlMap::isWF() const {
1046b88c852SAart Bik const auto ranks = getRanks();
1056b88c852SAart Bik unsigned dimNum = 0;
1066b88c852SAart Bik for (const auto &dimSpec : dimSpecs)
1076b88c852SAart Bik if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks))
1086b88c852SAart Bik return false;
1096b88c852SAart Bik assert(dimNum == ranks.getDimRank());
1106b88c852SAart Bik unsigned lvlNum = 0;
1116b88c852SAart Bik for (const auto &lvlSpec : lvlSpecs)
1126b88c852SAart Bik if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks))
1136b88c852SAart Bik return false;
1146b88c852SAart Bik assert(lvlNum == ranks.getLvlRank());
1156b88c852SAart Bik return true;
1166b88c852SAart Bik }
1176b88c852SAart Bik
getDimToLvlMap(MLIRContext * context) const118fdbe9312Swren romano AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
119fdbe9312Swren romano SmallVector<AffineExpr> lvlAffines;
120fdbe9312Swren romano lvlAffines.reserve(getLvlRank());
121fdbe9312Swren romano for (const auto &lvlSpec : lvlSpecs)
122fdbe9312Swren romano lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
123c3160f86Syinying-lisa-li auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
124c3160f86Syinying-lisa-li return map;
125fdbe9312Swren romano }
126fdbe9312Swren romano
getLvlToDimMap(MLIRContext * context) const127fdbe9312Swren romano AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
128fdbe9312Swren romano SmallVector<AffineExpr> dimAffines;
129fdbe9312Swren romano dimAffines.reserve(getDimRank());
1307b9fb1c2SYinying Li for (const auto &dimSpec : dimSpecs) {
1317b9fb1c2SYinying Li auto expr = dimSpec.getExpr().getAffineExpr();
1327b9fb1c2SYinying Li if (expr) {
1337b9fb1c2SYinying Li dimAffines.push_back(expr);
1347b9fb1c2SYinying Li }
1357b9fb1c2SYinying Li }
136c3160f86Syinying-lisa-li auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
137b165650aSYinying Li // If no lvlToDim map was passed in, returns a null AffineMap and infers it
138b165650aSYinying Li // in SparseTensorEncodingAttr::parse.
139b165650aSYinying Li if (dimAffines.empty())
1407b9fb1c2SYinying Li return AffineMap();
141c3160f86Syinying-lisa-li return map;
142fdbe9312Swren romano }
143fdbe9312Swren romano
1446b88c852SAart Bik //===----------------------------------------------------------------------===//
145