xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
1 //===- Var.cpp ------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Var.h"
10 #include "DimLvlMap.h"
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 using namespace mlir::sparse_tensor::ir_detail;
15 
16 //===----------------------------------------------------------------------===//
17 // `VarKind` helpers.
18 //===----------------------------------------------------------------------===//
19 
20 /// For use in foreach loops.
21 static constexpr const VarKind everyVarKind[] = {
22     VarKind::Dimension, VarKind::Symbol, VarKind::Level};
23 
24 //===----------------------------------------------------------------------===//
25 // `Var` implementation.
26 //===----------------------------------------------------------------------===//
27 
28 std::string Var::str() const {
29   std::string str;
30   llvm::raw_string_ostream os(str);
31   print(os);
32   return str;
33 }
34 
35 void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
36 
37 void Var::print(llvm::raw_ostream &os) const {
38   os << toChar(getKind()) << getNum();
39 }
40 
41 void Var::dump() const {
42   print(llvm::errs());
43   llvm::errs() << "\n";
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // `Ranks` implementation.
48 //===----------------------------------------------------------------------===//
49 
50 bool Ranks::operator==(Ranks const &other) const {
51   for (const auto vk : everyVarKind)
52     if (getRank(vk) != other.getRank(vk))
53       return false;
54   return true;
55 }
56 
57 bool Ranks::isValid(DimLvlExpr expr) const {
58   assert(expr);
59   // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
60   // (each `DimLvlExpr` only allows one kind of non-symbol variable).
61   int64_t maxSym = -1, maxVar = -1;
62   mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}},
63                                                  maxVar, maxSym);
64   return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind());
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // `VarSet` implementation.
69 //===----------------------------------------------------------------------===//
70 
71 VarSet::VarSet(Ranks const &ranks) {
72   for (const auto vk : everyVarKind)
73     impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
74   assert(getRanks() == ranks);
75 }
76 
77 bool VarSet::contains(Var var) const {
78   // NOTE: We make sure to return false on OOB, for consistency with
79   // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
80   // However beware that, as always with silencing OOB, this can hide
81   // bugs in client code.
82   const llvm::SmallBitVector &bits = impl[var.getKind()];
83   const auto num = var.getNum();
84   return num < bits.size() && bits[num];
85 }
86 
87 void VarSet::add(Var var) {
88   // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
89   impl[var.getKind()][var.getNum()] = true;
90 }
91 
92 void VarSet::add(VarSet const &other) {
93   // NOTE: `SmallBitVector::operator&=` will implicitly resize
94   // the bitvector (unlike `BitVector::operator&=`), so we add an
95   // assertion against OOB for consistency with the implementation
96   // of `VarSet::add(Var)`.
97   for (const auto vk : everyVarKind) {
98     assert(impl[vk].size() >= other.impl[vk].size());
99     impl[vk] &= other.impl[vk];
100   }
101 }
102 
103 void VarSet::add(DimLvlExpr expr) {
104   if (!expr)
105     return;
106   switch (expr.getAffineKind()) {
107   case AffineExprKind::Constant:
108     return;
109   case AffineExprKind::SymbolId:
110     add(expr.castSymVar());
111     return;
112   case AffineExprKind::DimId:
113     add(expr.castDimLvlVar());
114     return;
115   case AffineExprKind::Add:
116   case AffineExprKind::Mul:
117   case AffineExprKind::Mod:
118   case AffineExprKind::FloorDiv:
119   case AffineExprKind::CeilDiv: {
120     const auto [lhs, op, rhs] = expr.unpackBinop();
121     (void)op;
122     add(lhs);
123     add(rhs);
124     return;
125   }
126   }
127   llvm_unreachable("unknown AffineExprKind");
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // `VarInfo` implementation.
132 //===----------------------------------------------------------------------===//
133 
134 void VarInfo::setNum(Var::Num n) {
135   assert(!hasNum() && "Var::Num is already set");
136   assert(Var::isWF_Num(n) && "Var::Num is too large");
137   num = n;
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // `VarEnv` implementation.
142 //===----------------------------------------------------------------------===//
143 
144 /// Helper function for `assertUsageConsistency` to better handle SMLoc
145 /// mismatches.
146 LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
147 minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
148   const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
149   assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
150   const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
151   assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
152   if (loc1.getFilename() != loc2.getFilename())
153     return SMLoc();
154   const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
155   const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
156   return pair1 <= pair2 ? sm1 : sm2;
157 }
158 
159 bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
160   const auto &var = env.access(id);
161   return (var.getName() == name && var.getID() == id);
162 }
163 
164 bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
165                        VarKind vk) {
166   const auto &var = env.access(id);
167   return var.getKind() == vk;
168 }
169 
170 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
171   const auto iter = ids.find(name);
172   if (iter == ids.end())
173     return std::nullopt;
174   const auto id = iter->second;
175   if (!isInternalConsistent(*this, id, name))
176     return std::nullopt;
177   return id;
178 }
179 
180 std::optional<std::pair<VarInfo::ID, bool>>
181 VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
182   const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
183   const auto id = iter->second;
184   if (didInsert) {
185     vars.emplace_back(id, name, loc, vk);
186   } else {
187   if (!isInternalConsistent(*this, id, name))
188     return std::nullopt;
189   if (verifyUsage)
190     if (!isUsageConsistent(*this, id, loc, vk))
191       return std::nullopt;
192   }
193   return std::make_pair(id, didInsert);
194 }
195 
196 std::optional<std::pair<VarInfo::ID, bool>>
197 VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
198                        VarKind vk) {
199   switch (creationPolicy) {
200   case Policy::MustNot: {
201     const auto oid = lookup(name);
202     if (!oid)
203       return std::nullopt;  // Doesn't exist, but must not create.
204     if (!isUsageConsistent(*this, *oid, loc, vk))
205       return std::nullopt;
206     return std::make_pair(*oid, false);
207   }
208   case Policy::May:
209     return create(name, loc, vk, /*verifyUsage=*/true);
210   case Policy::Must: {
211     const auto res = create(name, loc, vk, /*verifyUsage=*/false);
212     const auto didCreate = res->second;
213     if (!didCreate)
214       return std::nullopt;  // Already exists, but must create.
215     return res;
216   }
217   }
218   llvm_unreachable("unknown Policy");
219 }
220 
221 Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
222 Var VarEnv::bindVar(VarInfo::ID id) {
223   auto &info = access(id);
224   const auto var = bindUnusedVar(info.getKind());
225   info.setNum(var.getNum());
226   return var;
227 }
228 
229 InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
230   for (const auto &var : vars)
231     if (!var.hasNum())
232       return parser.emitError(var.getLoc(),
233                               "Unbound variable: " + var.getName());
234   return {};
235 }
236 
237 //===----------------------------------------------------------------------===//
238