xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp (revision 1944c4f76b47c0b86c91845987baca24fd4775f8)
1 //===- DimLvlMap.cpp ------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DimLvlMap.h"
10 
11 using namespace mlir;
12 using namespace mlir::sparse_tensor;
13 using namespace mlir::sparse_tensor::ir_detail;
14 
15 //===----------------------------------------------------------------------===//
16 // `DimLvlExpr` implementation.
17 //===----------------------------------------------------------------------===//
18 
castSymVar() const19 SymVar DimLvlExpr::castSymVar() const {
20   return SymVar(llvm::cast<AffineSymbolExpr>(expr));
21 }
22 
dyn_castSymVar() const23 std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
24   if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
25     return SymVar(s);
26   return std::nullopt;
27 }
28 
castDimLvlVar() const29 Var DimLvlExpr::castDimLvlVar() const {
30   return Var(getAllowedVarKind(), llvm::cast<AffineDimExpr>(expr));
31 }
32 
dyn_castDimLvlVar() const33 std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
34   if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
35     return Var(getAllowedVarKind(), x);
36   return std::nullopt;
37 }
38 
39 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
unpackBinop() const40 DimLvlExpr::unpackBinop() const {
41   const auto ak = getAffineKind();
42   const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
43   const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
44   const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
45   return {lhs, ak, rhs};
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // `DimSpec` implementation.
50 //===----------------------------------------------------------------------===//
51 
DimSpec(DimVar var,DimExpr expr,SparseTensorDimSliceAttr slice)52 DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
53     : var(var), expr(expr), slice(slice) {}
54 
isValid(Ranks const & ranks) const55 bool DimSpec::isValid(Ranks const &ranks) const {
56   // Nothing in `slice` needs additional validation.
57   // We explicitly consider null-expr to be vacuously valid.
58   return ranks.isValid(var) && (!expr || ranks.isValid(expr));
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // `LvlSpec` implementation.
63 //===----------------------------------------------------------------------===//
64 
LvlSpec(LvlVar var,LvlExpr expr,LevelType type)65 LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
66     : var(var), expr(expr), type(type) {
67   assert(expr);
68   assert(isValidLT(type) && !isUndefLT(type));
69 }
70 
isValid(Ranks const & ranks) const71 bool LvlSpec::isValid(Ranks const &ranks) const {
72   // Nothing in `type` needs additional validation.
73   return ranks.isValid(var) && ranks.isValid(expr);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // `DimLvlMap` implementation.
78 //===----------------------------------------------------------------------===//
79 
DimLvlMap(unsigned symRank,ArrayRef<DimSpec> dimSpecs,ArrayRef<LvlSpec> lvlSpecs)80 DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
81                      ArrayRef<LvlSpec> lvlSpecs)
82     : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
83       mustPrintLvlVars(false) {
84   // First, check integrity of the variable-binding structure.
85   // NOTE: This establishes the invariant that calls to `VarSet::add`
86   // below cannot cause OOB errors.
87   assert(isWF());
88 
89   VarSet usedVars(getRanks());
90   for (const auto &dimSpec : dimSpecs)
91     if (!dimSpec.canElideExpr())
92       usedVars.add(dimSpec.getExpr());
93   for (auto &lvlSpec : this->lvlSpecs) {
94     // Is this LvlVar used in any overt expression?
95     const bool isUsed = usedVars.contains(lvlSpec.getBoundVar());
96     // This LvlVar can be elided iff it isn't overtly used.
97     lvlSpec.setElideVar(!isUsed);
98     // If any LvlVar cannot be elided, then must forward-declare all LvlVars.
99     mustPrintLvlVars = mustPrintLvlVars || isUsed;
100   }
101 }
102 
isWF() const103 bool DimLvlMap::isWF() const {
104   const auto ranks = getRanks();
105   unsigned dimNum = 0;
106   for (const auto &dimSpec : dimSpecs)
107     if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks))
108       return false;
109   assert(dimNum == ranks.getDimRank());
110   unsigned lvlNum = 0;
111   for (const auto &lvlSpec : lvlSpecs)
112     if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks))
113       return false;
114   assert(lvlNum == ranks.getLvlRank());
115   return true;
116 }
117 
getDimToLvlMap(MLIRContext * context) const118 AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
119   SmallVector<AffineExpr> lvlAffines;
120   lvlAffines.reserve(getLvlRank());
121   for (const auto &lvlSpec : lvlSpecs)
122     lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
123   auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
124   return map;
125 }
126 
getLvlToDimMap(MLIRContext * context) const127 AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
128   SmallVector<AffineExpr> dimAffines;
129   dimAffines.reserve(getDimRank());
130   for (const auto &dimSpec : dimSpecs) {
131     auto expr = dimSpec.getExpr().getAffineExpr();
132     if (expr) {
133       dimAffines.push_back(expr);
134     }
135   }
136   auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
137   // If no lvlToDim map was passed in, returns a null AffineMap and infers it
138   // in SparseTensorEncodingAttr::parse.
139   if (dimAffines.empty())
140     return AffineMap();
141   return map;
142 }
143 
144 //===----------------------------------------------------------------------===//
145