xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp (revision 7b9fb1c228337de9c866b123ff60f3491eebd3d7)
1 //===- DimLvlMap.cpp ------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DimLvlMap.h"
10 
11 using namespace mlir;
12 using namespace mlir::sparse_tensor;
13 using namespace mlir::sparse_tensor::ir_detail;
14 
15 //===----------------------------------------------------------------------===//
16 // `DimLvlExpr` implementation.
17 //===----------------------------------------------------------------------===//
18 
19 Var DimLvlExpr::castAnyVar() const {
20   assert(expr && "uninitialized DimLvlExpr");
21   const auto var = dyn_castAnyVar();
22   assert(var && "expected DimLvlExpr to be a Var");
23   return *var;
24 }
25 
26 std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
27   if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
28     return SymVar(s);
29   if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
30     return Var(getAllowedVarKind(), x);
31   return std::nullopt;
32 }
33 
34 SymVar DimLvlExpr::castSymVar() const {
35   return SymVar(expr.cast<AffineSymbolExpr>());
36 }
37 
38 std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
39   if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
40     return SymVar(s);
41   return std::nullopt;
42 }
43 
44 Var DimLvlExpr::castDimLvlVar() const {
45   return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>());
46 }
47 
48 std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
49   if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
50     return Var(getAllowedVarKind(), x);
51   return std::nullopt;
52 }
53 
54 int64_t DimLvlExpr::castConstantValue() const {
55   return expr.cast<AffineConstantExpr>().getValue();
56 }
57 
58 std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
59   const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
60   return k ? std::make_optional(k.getValue()) : std::nullopt;
61 }
62 
63 // This helper method is akin to `AffineExpr::operator==(int64_t)`
64 // except it uses a different implementation, namely the implementation
65 // used within `AsmPrinter::Impl::printAffineExprInternal`.
66 //
67 // wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
68 // this implementation because it avoids constructing the intermediate
69 // `AffineConstantExpr(val)` and thus should in theory be a bit faster.
70 // However, if it is indeed faster, then the `AffineExpr::operator==`
71 // method should be updated to do this instead.  And if it isn't any
72 // faster, then we should be using `AffineExpr::operator==` instead.
73 bool DimLvlExpr::hasConstantValue(int64_t val) const {
74   const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
75   return k && k.getValue() == val;
76 }
77 
78 DimLvlExpr DimLvlExpr::getLHS() const {
79   const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>();
80   return DimLvlExpr(kind, binop ? binop.getLHS() : nullptr);
81 }
82 
83 DimLvlExpr DimLvlExpr::getRHS() const {
84   const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>();
85   return DimLvlExpr(kind, binop ? binop.getRHS() : nullptr);
86 }
87 
88 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
89 DimLvlExpr::unpackBinop() const {
90   const auto ak = getAffineKind();
91   const auto binop = expr.dyn_cast<AffineBinaryOpExpr>();
92   const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
93   const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
94   return {lhs, ak, rhs};
95 }
96 
97 void DimLvlExpr::dump() const {
98   print(llvm::errs());
99   llvm::errs() << "\n";
100 }
101 std::string DimLvlExpr::str() const {
102   std::string str;
103   llvm::raw_string_ostream os(str);
104   print(os);
105   return os.str();
106 }
107 void DimLvlExpr::print(AsmPrinter &printer) const {
108   print(printer.getStream());
109 }
110 void DimLvlExpr::print(llvm::raw_ostream &os) const {
111   if (!expr)
112     os << "<<NULL AFFINE EXPR>>";
113   else
114     printWeak(os);
115 }
116 
117 namespace {
118 struct MatchNeg final : public std::pair<DimLvlExpr, int64_t> {
119   using Base = std::pair<DimLvlExpr, int64_t>;
120   using Base::Base;
121   constexpr DimLvlExpr getLHS() const { return first; }
122   constexpr int64_t getRHS() const { return second; }
123 };
124 } // namespace
125 
126 static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
127   const auto [lhs, op, rhs] = expr.unpackBinop();
128   if (op == AffineExprKind::Constant) {
129     const auto val = expr.castConstantValue();
130     if (val < 0)
131       return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
132   }
133   if (op == AffineExprKind::Mul)
134     if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0)
135       return MatchNeg{lhs, *rval};
136   return std::nullopt;
137 }
138 
139 // A heavily revised version of `AsmPrinter::Impl::printAffineExprInternal`.
140 void DimLvlExpr::printAffineExprInternal(
141     llvm::raw_ostream &os, BindingStrength enclosingTightness) const {
142   const char *binopSpelling = nullptr;
143   switch (getAffineKind()) {
144   case AffineExprKind::SymbolId:
145     os << castSymVar();
146     return;
147   case AffineExprKind::DimId:
148     os << castDimLvlVar();
149     return;
150   case AffineExprKind::Constant:
151     os << castConstantValue();
152     return;
153   case AffineExprKind::Add:
154     binopSpelling = " + "; // N.B., this is unused
155     break;
156   case AffineExprKind::Mul:
157     binopSpelling = " * ";
158     break;
159   case AffineExprKind::FloorDiv:
160     binopSpelling = " floordiv ";
161     break;
162   case AffineExprKind::CeilDiv:
163     binopSpelling = " ceildiv ";
164     break;
165   case AffineExprKind::Mod:
166     binopSpelling = " mod ";
167     break;
168   }
169 
170   if (enclosingTightness == BindingStrength::Strong)
171     os << '(';
172 
173   const auto [lhs, op, rhs] = unpackBinop();
174   if (op == AffineExprKind::Mul && rhs.hasConstantValue(-1)) {
175     // Pretty print `(lhs * -1)` as "-lhs".
176     os << '-';
177     lhs.printStrong(os);
178   } else if (op != AffineExprKind::Add) {
179     // Default rule for tightly binding binary operators.
180     // (Including `Mul` that didn't match the previous rule.)
181     lhs.printStrong(os);
182     os << binopSpelling;
183     rhs.printStrong(os);
184   } else {
185     // Combination of all the special rules for addition/subtraction.
186     lhs.printWeak(os);
187     const auto rx = matchNeg(rhs);
188     os << (rx ? " - " : " + ");
189     const auto &rlhs = rx ? rx->getLHS() : rhs;
190     const auto rrhs = rx ? rx->getRHS() : -1; // value irrelevant when `!rx`
191     const bool nonunit = rrhs != -1;          // value irrelevant when `!rx`
192     const bool isStrong =
193         rx && rlhs && (nonunit || rlhs.getAffineKind() == AffineExprKind::Add);
194     if (rlhs)
195       rlhs.printAffineExprInternal(os, BindingStrength{isStrong});
196     if (rx && rlhs && nonunit)
197       os << " * ";
198     if (rx && (!rlhs || nonunit))
199       os << -rrhs;
200   }
201 
202   if (enclosingTightness == BindingStrength::Strong)
203     os << ')';
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // `DimSpec` implementation.
208 //===----------------------------------------------------------------------===//
209 
210 DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
211     : var(var), expr(expr), slice(slice) {}
212 
213 bool DimSpec::isValid(Ranks const &ranks) const {
214   // Nothing in `slice` needs additional validation.
215   // We explicitly consider null-expr to be vacuously valid.
216   return ranks.isValid(var) && (!expr || ranks.isValid(expr));
217 }
218 
219 bool DimSpec::isFunctionOf(VarSet const &vars) const {
220   return vars.occursIn(expr);
221 }
222 
223 void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
224 
225 void DimSpec::dump() const {
226   print(llvm::errs(), /*wantElision=*/false);
227   llvm::errs() << "\n";
228 }
229 std::string DimSpec::str(bool wantElision) const {
230   std::string str;
231   llvm::raw_string_ostream os(str);
232   print(os, wantElision);
233   return os.str();
234 }
235 void DimSpec::print(AsmPrinter &printer, bool wantElision) const {
236   print(printer.getStream(), wantElision);
237 }
238 void DimSpec::print(llvm::raw_ostream &os, bool wantElision) const {
239   os << var;
240   if (expr && (!wantElision || !elideExpr))
241     os << " = " << expr;
242   if (slice) {
243     os << " : ";
244     // Call `SparseTensorDimSliceAttr::print` directly, to avoid
245     // printing the mnemonic.
246     slice.print(os);
247   }
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // `LvlSpec` implementation.
252 //===----------------------------------------------------------------------===//
253 
254 LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type)
255     : var(var), expr(expr), type(type) {
256   assert(expr);
257   assert(isValidDLT(type) && !isUndefDLT(type));
258 }
259 
260 bool LvlSpec::isValid(Ranks const &ranks) const {
261   // Nothing in `type` needs additional validation.
262   return ranks.isValid(var) && ranks.isValid(expr);
263 }
264 
265 bool LvlSpec::isFunctionOf(VarSet const &vars) const {
266   return vars.occursIn(expr);
267 }
268 
269 void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
270 
271 void LvlSpec::dump() const {
272   print(llvm::errs(), /*wantElision=*/false);
273   llvm::errs() << "\n";
274 }
275 std::string LvlSpec::str(bool wantElision) const {
276   std::string str;
277   llvm::raw_string_ostream os(str);
278   print(os, wantElision);
279   return os.str();
280 }
281 void LvlSpec::print(AsmPrinter &printer, bool wantElision) const {
282   print(printer.getStream(), wantElision);
283 }
284 void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const {
285   if (!wantElision || !elideVar)
286     os << var << " = ";
287   os << expr;
288   os << ": " << toMLIRString(type);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // `DimLvlMap` implementation.
293 //===----------------------------------------------------------------------===//
294 
295 DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
296                      ArrayRef<LvlSpec> lvlSpecs)
297     : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
298       mustPrintLvlVars(false) {
299   // First, check integrity of the variable-binding structure.
300   // NOTE: This establishes the invariant that calls to `VarSet::add`
301   // below cannot cause OOB errors.
302   assert(isWF());
303 
304   // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
305   // Along the way we should set every `DimSpec::elideExpr` according
306   // to whether the given expression is inferable or not.  Notably, this
307   // needs to happen before the code for setting every `LvlSpec::elideVar`,
308   // since if the LvlVar is only used in elided DimExpr, then the
309   // LvlVar should also be elided.
310   // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
311   // to ensure that we maintain the invariant established by `isWF` above.
312 
313   // Third, we set every `LvlSpec::elideVar` according to whether that
314   // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
315   // NOTE: The invariant established by `isWF` ensures that the following
316   // calls to `VarSet::add` cannot raise OOB errors.
317   VarSet usedVars(getRanks());
318   for (const auto &dimSpec : dimSpecs)
319     if (!dimSpec.canElideExpr())
320       usedVars.add(dimSpec.getExpr());
321   for (auto &lvlSpec : this->lvlSpecs) {
322     // Is this LvlVar used in any overt expression?
323     const bool isUsed = usedVars.contains(lvlSpec.getBoundVar());
324     // This LvlVar can be elided iff it isn't overtly used.
325     lvlSpec.setElideVar(!isUsed);
326     // If any LvlVar cannot be elided, then must forward-declare all LvlVars.
327     mustPrintLvlVars = mustPrintLvlVars || isUsed;
328   }
329 }
330 
331 bool DimLvlMap::isWF() const {
332   const auto ranks = getRanks();
333   unsigned dimNum = 0;
334   for (const auto &dimSpec : dimSpecs)
335     if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks))
336       return false;
337   assert(dimNum == ranks.getDimRank());
338   unsigned lvlNum = 0;
339   for (const auto &lvlSpec : lvlSpecs)
340     if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks))
341       return false;
342   assert(lvlNum == ranks.getLvlRank());
343   return true;
344 }
345 
346 AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
347   SmallVector<AffineExpr> lvlAffines;
348   lvlAffines.reserve(getLvlRank());
349   for (const auto &lvlSpec : lvlSpecs)
350     lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
351   auto map =  AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
352   if (map.isIdentity()) return AffineMap();
353   return map;
354 }
355 
356 AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
357   SmallVector<AffineExpr> dimAffines;
358   dimAffines.reserve(getDimRank());
359   for (const auto &dimSpec : dimSpecs) {
360     auto expr = dimSpec.getExpr().getAffineExpr();
361     if (expr) {
362       dimAffines.push_back(expr);
363     }
364   }
365   auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
366   if (dimAffines.empty() || map.isIdentity())
367     return AffineMap();
368   return map;
369 }
370 
371 void DimLvlMap::dump() const {
372   print(llvm::errs(), /*wantElision=*/false);
373   llvm::errs() << "\n";
374 }
375 std::string DimLvlMap::str(bool wantElision) const {
376   std::string str;
377   llvm::raw_string_ostream os(str);
378   print(os, wantElision);
379   return os.str();
380 }
381 void DimLvlMap::print(AsmPrinter &printer, bool wantElision) const {
382   print(printer.getStream(), wantElision);
383 }
384 void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
385   // Symbolic identifiers.
386   // NOTE: Unlike `AffineMap` we place the SymVar bindings before the DimVar
387   // bindings, since the SymVars may occur within DimExprs and thus this
388   // ordering helps reduce potential user confusion about the scope of bidings
389   // (since it means SymVars and DimVars both bind-forward in the usual way,
390   // whereas only LvlVars have different binding rules).
391   if (symRank != 0) {
392     os << "[s0";
393     for (unsigned i = 1; i < symRank; ++i)
394       os << ", s" << i;
395     os << ']';
396   }
397 
398   // LvlVar forward-declarations.
399   if (mustPrintLvlVars) {
400     os << '{';
401     llvm::interleaveComma(
402         lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
403     os << "} ";
404   }
405 
406   // Dimension specifiers.
407   os << '(';
408   llvm::interleaveComma(
409       dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); });
410   os << ") -> (";
411   // Level specifiers.
412   wantElision = wantElision && !mustPrintLvlVars;
413   llvm::interleaveComma(
414       lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); });
415   os << ')';
416 }
417 
418 //===----------------------------------------------------------------------===//
419