xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
//===- Var.cpp ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
86b88c852SAart Bik 
96b88c852SAart Bik #include "Var.h"
106b88c852SAart Bik #include "DimLvlMap.h"
116b88c852SAart Bik 
126b88c852SAart Bik using namespace mlir;
136b88c852SAart Bik using namespace mlir::sparse_tensor;
146b88c852SAart Bik using namespace mlir::sparse_tensor::ir_detail;
156b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `VarKind` helpers.
//===----------------------------------------------------------------------===//
195df63ad8Swren romano 
205df63ad8Swren romano /// For use in foreach loops.
215df63ad8Swren romano static constexpr const VarKind everyVarKind[] = {
225df63ad8Swren romano     VarKind::Dimension, VarKind::Symbol, VarKind::Level};
235df63ad8Swren romano 
//===----------------------------------------------------------------------===//
// `Var` implementation.
//===----------------------------------------------------------------------===//
276b88c852SAart Bik 
28f5b974b7Swren romano std::string Var::str() const {
29f5b974b7Swren romano   std::string str;
30f5b974b7Swren romano   llvm::raw_string_ostream os(str);
31f5b974b7Swren romano   print(os);
32*884221edSJOE1994   return str;
33f5b974b7Swren romano }
34f5b974b7Swren romano 
356b88c852SAart Bik void Var::print(AsmPrinter &printer) const { print(printer.getStream()); }
366b88c852SAart Bik 
376b88c852SAart Bik void Var::print(llvm::raw_ostream &os) const {
386b88c852SAart Bik   os << toChar(getKind()) << getNum();
396b88c852SAart Bik }
406b88c852SAart Bik 
416b88c852SAart Bik void Var::dump() const {
426b88c852SAart Bik   print(llvm::errs());
436b88c852SAart Bik   llvm::errs() << "\n";
446b88c852SAart Bik }
456b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `Ranks` implementation.
//===----------------------------------------------------------------------===//
496b88c852SAart Bik 
505df63ad8Swren romano bool Ranks::operator==(Ranks const &other) const {
515df63ad8Swren romano   for (const auto vk : everyVarKind)
525df63ad8Swren romano     if (getRank(vk) != other.getRank(vk))
535df63ad8Swren romano       return false;
545df63ad8Swren romano   return true;
555df63ad8Swren romano }
565df63ad8Swren romano 
576b88c852SAart Bik bool Ranks::isValid(DimLvlExpr expr) const {
583b00f448Swren romano   assert(expr);
593b00f448Swren romano   // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
603b00f448Swren romano   // (each `DimLvlExpr` only allows one kind of non-symbol variable).
616b88c852SAart Bik   int64_t maxSym = -1, maxVar = -1;
626b88c852SAart Bik   mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}},
636b88c852SAart Bik                                                  maxVar, maxSym);
646b88c852SAart Bik   return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind());
656b88c852SAart Bik }
666b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `VarSet` implementation.
//===----------------------------------------------------------------------===//
706b88c852SAart Bik 
716b88c852SAart Bik VarSet::VarSet(Ranks const &ranks) {
726b88c852SAart Bik   for (const auto vk : everyVarKind)
73497050c9Swren romano     impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
745df63ad8Swren romano   assert(getRanks() == ranks);
756b88c852SAart Bik }
766b88c852SAart Bik 
776b88c852SAart Bik bool VarSet::contains(Var var) const {
78dcadb68aSwren romano   // NOTE: We make sure to return false on OOB, for consistency with
79dcadb68aSwren romano   // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
80dcadb68aSwren romano   // However beware that, as always with silencing OOB, this can hide
81dcadb68aSwren romano   // bugs in client code.
826b88c852SAart Bik   const llvm::SmallBitVector &bits = impl[var.getKind()];
83dcadb68aSwren romano   const auto num = var.getNum();
84dcadb68aSwren romano   return num < bits.size() && bits[num];
856b88c852SAart Bik }
866b88c852SAart Bik 
876b88c852SAart Bik void VarSet::add(Var var) {
88fdbe9312Swren romano   // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
896b88c852SAart Bik   impl[var.getKind()][var.getNum()] = true;
906b88c852SAart Bik }
916b88c852SAart Bik 
92dcadb68aSwren romano void VarSet::add(VarSet const &other) {
93dcadb68aSwren romano   // NOTE: `SmallBitVector::operator&=` will implicitly resize
94dcadb68aSwren romano   // the bitvector (unlike `BitVector::operator&=`), so we add an
95dcadb68aSwren romano   // assertion against OOB for consistency with the implementation
96dcadb68aSwren romano   // of `VarSet::add(Var)`.
97dcadb68aSwren romano   for (const auto vk : everyVarKind) {
98dcadb68aSwren romano     assert(impl[vk].size() >= other.impl[vk].size());
99dcadb68aSwren romano     impl[vk] &= other.impl[vk];
100dcadb68aSwren romano   }
101dcadb68aSwren romano }
1026b88c852SAart Bik 
1036b88c852SAart Bik void VarSet::add(DimLvlExpr expr) {
1046b88c852SAart Bik   if (!expr)
1056b88c852SAart Bik     return;
1066b88c852SAart Bik   switch (expr.getAffineKind()) {
1076b88c852SAart Bik   case AffineExprKind::Constant:
1086b88c852SAart Bik     return;
1096b88c852SAart Bik   case AffineExprKind::SymbolId:
1106b88c852SAart Bik     add(expr.castSymVar());
1116b88c852SAart Bik     return;
1126b88c852SAart Bik   case AffineExprKind::DimId:
1136b88c852SAart Bik     add(expr.castDimLvlVar());
1146b88c852SAart Bik     return;
1156b88c852SAart Bik   case AffineExprKind::Add:
1166b88c852SAart Bik   case AffineExprKind::Mul:
1176b88c852SAart Bik   case AffineExprKind::Mod:
1186b88c852SAart Bik   case AffineExprKind::FloorDiv:
1196b88c852SAart Bik   case AffineExprKind::CeilDiv: {
1206b88c852SAart Bik     const auto [lhs, op, rhs] = expr.unpackBinop();
1216b88c852SAart Bik     (void)op;
1226b88c852SAart Bik     add(lhs);
1236b88c852SAart Bik     add(rhs);
1246b88c852SAart Bik     return;
1256b88c852SAart Bik   }
1266b88c852SAart Bik   }
1276b88c852SAart Bik   llvm_unreachable("unknown AffineExprKind");
1286b88c852SAart Bik }
1296b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `VarInfo` implementation.
//===----------------------------------------------------------------------===//
1336b88c852SAart Bik 
1346b88c852SAart Bik void VarInfo::setNum(Var::Num n) {
1356b88c852SAart Bik   assert(!hasNum() && "Var::Num is already set");
1366b88c852SAart Bik   assert(Var::isWF_Num(n) && "Var::Num is too large");
1376b88c852SAart Bik   num = n;
1386b88c852SAart Bik }
1396b88c852SAart Bik 
//===----------------------------------------------------------------------===//
// `VarEnv` implementation.
//===----------------------------------------------------------------------===//
1436b88c852SAart Bik 
1446b88c852SAart Bik /// Helper function for `assertUsageConsistency` to better handle SMLoc
1456b88c852SAart Bik /// mismatches.
1466b88c852SAart Bik LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
1476b88c852SAart Bik minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
148a5757c5bSChristian Sigg   const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
1496b88c852SAart Bik   assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
150a5757c5bSChristian Sigg   const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
1516b88c852SAart Bik   assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
1526b88c852SAart Bik   if (loc1.getFilename() != loc2.getFilename())
1536b88c852SAart Bik     return SMLoc();
1546b88c852SAart Bik   const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn());
1556b88c852SAart Bik   const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn());
1566b88c852SAart Bik   return pair1 <= pair2 ? sm1 : sm2;
1576b88c852SAart Bik }
1586b88c852SAart Bik 
1598466eb7dSYinying Li bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
1606b88c852SAart Bik   const auto &var = env.access(id);
1618466eb7dSYinying Li   return (var.getName() == name && var.getID() == id);
1626b88c852SAart Bik }
1636b88c852SAart Bik 
1648466eb7dSYinying Li bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
1656b88c852SAart Bik                        VarKind vk) {
1666b88c852SAart Bik   const auto &var = env.access(id);
1678466eb7dSYinying Li   return var.getKind() == vk;
1686b88c852SAart Bik }
1696b88c852SAart Bik 
1706b88c852SAart Bik std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
1716b88c852SAart Bik   const auto iter = ids.find(name);
1726b88c852SAart Bik   if (iter == ids.end())
1736b88c852SAart Bik     return std::nullopt;
1746b88c852SAart Bik   const auto id = iter->second;
1758466eb7dSYinying Li   if (!isInternalConsistent(*this, id, name))
1768466eb7dSYinying Li     return std::nullopt;
1776b88c852SAart Bik   return id;
1786b88c852SAart Bik }
1796b88c852SAart Bik 
1808466eb7dSYinying Li std::optional<std::pair<VarInfo::ID, bool>>
1818466eb7dSYinying Li VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
1826b88c852SAart Bik   const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
1836b88c852SAart Bik   const auto id = iter->second;
1846b88c852SAart Bik   if (didInsert) {
1856b88c852SAart Bik     vars.emplace_back(id, name, loc, vk);
1866b88c852SAart Bik   } else {
1878466eb7dSYinying Li   if (!isInternalConsistent(*this, id, name))
1888466eb7dSYinying Li     return std::nullopt;
1896b88c852SAart Bik   if (verifyUsage)
1908466eb7dSYinying Li     if (!isUsageConsistent(*this, id, loc, vk))
1918466eb7dSYinying Li       return std::nullopt;
1926b88c852SAart Bik   }
1936b88c852SAart Bik   return std::make_pair(id, didInsert);
1946b88c852SAart Bik }
1956b88c852SAart Bik 
1966b88c852SAart Bik std::optional<std::pair<VarInfo::ID, bool>>
197ad7a6b67Swren romano VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
1986b88c852SAart Bik                        VarKind vk) {
199ad7a6b67Swren romano   switch (creationPolicy) {
200ad7a6b67Swren romano   case Policy::MustNot: {
2016b88c852SAart Bik     const auto oid = lookup(name);
2026b88c852SAart Bik     if (!oid)
2036b88c852SAart Bik       return std::nullopt;  // Doesn't exist, but must not create.
2048466eb7dSYinying Li     if (!isUsageConsistent(*this, *oid, loc, vk))
2058466eb7dSYinying Li       return std::nullopt;
2066b88c852SAart Bik     return std::make_pair(*oid, false);
2076b88c852SAart Bik   }
208ad7a6b67Swren romano   case Policy::May:
2096b88c852SAart Bik     return create(name, loc, vk, /*verifyUsage=*/true);
210ad7a6b67Swren romano   case Policy::Must: {
2116b88c852SAart Bik     const auto res = create(name, loc, vk, /*verifyUsage=*/false);
2128466eb7dSYinying Li     const auto didCreate = res->second;
2136b88c852SAart Bik     if (!didCreate)
2146b88c852SAart Bik       return std::nullopt;  // Already exists, but must create.
2156b88c852SAart Bik     return res;
2166b88c852SAart Bik   }
2176b88c852SAart Bik   }
218ad7a6b67Swren romano   llvm_unreachable("unknown Policy");
2196b88c852SAart Bik }
2206b88c852SAart Bik 
2216b88c852SAart Bik Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
2226b88c852SAart Bik Var VarEnv::bindVar(VarInfo::ID id) {
2236b88c852SAart Bik   auto &info = access(id);
2246b88c852SAart Bik   const auto var = bindUnusedVar(info.getKind());
2256b88c852SAart Bik   info.setNum(var.getNum());
2266b88c852SAart Bik   return var;
2276b88c852SAart Bik }
2286b88c852SAart Bik 
2296b88c852SAart Bik InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
2306b88c852SAart Bik   for (const auto &var : vars)
2316b88c852SAart Bik     if (!var.hasNum())
2326b88c852SAart Bik       return parser.emitError(var.getLoc(),
2336b88c852SAart Bik                               "Unbound variable: " + var.getName());
2346b88c852SAart Bik   return {};
2356b88c852SAart Bik }
2366b88c852SAart Bik 
2376b88c852SAart Bik //===----------------------------------------------------------------------===//
238