16b88c852SAart Bik //===- Var.cpp ------------------------------------------------------------===// 26b88c852SAart Bik // 36b88c852SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 46b88c852SAart Bik // See https://llvm.org/LICENSE.txt for license information. 56b88c852SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 66b88c852SAart Bik // 76b88c852SAart Bik //===----------------------------------------------------------------------===// 86b88c852SAart Bik 96b88c852SAart Bik #include "Var.h" 106b88c852SAart Bik #include "DimLvlMap.h" 116b88c852SAart Bik 126b88c852SAart Bik using namespace mlir; 136b88c852SAart Bik using namespace mlir::sparse_tensor; 146b88c852SAart Bik using namespace mlir::sparse_tensor::ir_detail; 156b88c852SAart Bik 166b88c852SAart Bik //===----------------------------------------------------------------------===// 175df63ad8Swren romano // `VarKind` helpers. 185df63ad8Swren romano //===----------------------------------------------------------------------===// 195df63ad8Swren romano 205df63ad8Swren romano /// For use in foreach loops. 215df63ad8Swren romano static constexpr const VarKind everyVarKind[] = { 225df63ad8Swren romano VarKind::Dimension, VarKind::Symbol, VarKind::Level}; 235df63ad8Swren romano 245df63ad8Swren romano //===----------------------------------------------------------------------===// 256b88c852SAart Bik // `Var` implementation. 
266b88c852SAart Bik //===----------------------------------------------------------------------===// 276b88c852SAart Bik 28f5b974b7Swren romano std::string Var::str() const { 29f5b974b7Swren romano std::string str; 30f5b974b7Swren romano llvm::raw_string_ostream os(str); 31f5b974b7Swren romano print(os); 32*884221edSJOE1994 return str; 33f5b974b7Swren romano } 34f5b974b7Swren romano 356b88c852SAart Bik void Var::print(AsmPrinter &printer) const { print(printer.getStream()); } 366b88c852SAart Bik 376b88c852SAart Bik void Var::print(llvm::raw_ostream &os) const { 386b88c852SAart Bik os << toChar(getKind()) << getNum(); 396b88c852SAart Bik } 406b88c852SAart Bik 416b88c852SAart Bik void Var::dump() const { 426b88c852SAart Bik print(llvm::errs()); 436b88c852SAart Bik llvm::errs() << "\n"; 446b88c852SAart Bik } 456b88c852SAart Bik 466b88c852SAart Bik //===----------------------------------------------------------------------===// 476b88c852SAart Bik // `Ranks` implementation. 486b88c852SAart Bik //===----------------------------------------------------------------------===// 496b88c852SAart Bik 505df63ad8Swren romano bool Ranks::operator==(Ranks const &other) const { 515df63ad8Swren romano for (const auto vk : everyVarKind) 525df63ad8Swren romano if (getRank(vk) != other.getRank(vk)) 535df63ad8Swren romano return false; 545df63ad8Swren romano return true; 555df63ad8Swren romano } 565df63ad8Swren romano 576b88c852SAart Bik bool Ranks::isValid(DimLvlExpr expr) const { 583b00f448Swren romano assert(expr); 593b00f448Swren romano // Compute the maximum identifiers for symbol-vars and dim/lvl-vars 603b00f448Swren romano // (each `DimLvlExpr` only allows one kind of non-symbol variable). 
616b88c852SAart Bik int64_t maxSym = -1, maxVar = -1; 626b88c852SAart Bik mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}}, 636b88c852SAart Bik maxVar, maxSym); 646b88c852SAart Bik return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind()); 656b88c852SAart Bik } 666b88c852SAart Bik 676b88c852SAart Bik //===----------------------------------------------------------------------===// 686b88c852SAart Bik // `VarSet` implementation. 696b88c852SAart Bik //===----------------------------------------------------------------------===// 706b88c852SAart Bik 716b88c852SAart Bik VarSet::VarSet(Ranks const &ranks) { 726b88c852SAart Bik for (const auto vk : everyVarKind) 73497050c9Swren romano impl[vk] = llvm::SmallBitVector(ranks.getRank(vk)); 745df63ad8Swren romano assert(getRanks() == ranks); 756b88c852SAart Bik } 766b88c852SAart Bik 776b88c852SAart Bik bool VarSet::contains(Var var) const { 78dcadb68aSwren romano // NOTE: We make sure to return false on OOB, for consistency with 79dcadb68aSwren romano // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`. 80dcadb68aSwren romano // However beware that, as always with silencing OOB, this can hide 81dcadb68aSwren romano // bugs in client code. 826b88c852SAart Bik const llvm::SmallBitVector &bits = impl[var.getKind()]; 83dcadb68aSwren romano const auto num = var.getNum(); 84dcadb68aSwren romano return num < bits.size() && bits[num]; 856b88c852SAart Bik } 866b88c852SAart Bik 876b88c852SAart Bik void VarSet::add(Var var) { 88fdbe9312Swren romano // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB. 
896b88c852SAart Bik impl[var.getKind()][var.getNum()] = true; 906b88c852SAart Bik } 916b88c852SAart Bik 92dcadb68aSwren romano void VarSet::add(VarSet const &other) { 93dcadb68aSwren romano // NOTE: `SmallBitVector::operator&=` will implicitly resize 94dcadb68aSwren romano // the bitvector (unlike `BitVector::operator&=`), so we add an 95dcadb68aSwren romano // assertion against OOB for consistency with the implementation 96dcadb68aSwren romano // of `VarSet::add(Var)`. 97dcadb68aSwren romano for (const auto vk : everyVarKind) { 98dcadb68aSwren romano assert(impl[vk].size() >= other.impl[vk].size()); 99dcadb68aSwren romano impl[vk] &= other.impl[vk]; 100dcadb68aSwren romano } 101dcadb68aSwren romano } 1026b88c852SAart Bik 1036b88c852SAart Bik void VarSet::add(DimLvlExpr expr) { 1046b88c852SAart Bik if (!expr) 1056b88c852SAart Bik return; 1066b88c852SAart Bik switch (expr.getAffineKind()) { 1076b88c852SAart Bik case AffineExprKind::Constant: 1086b88c852SAart Bik return; 1096b88c852SAart Bik case AffineExprKind::SymbolId: 1106b88c852SAart Bik add(expr.castSymVar()); 1116b88c852SAart Bik return; 1126b88c852SAart Bik case AffineExprKind::DimId: 1136b88c852SAart Bik add(expr.castDimLvlVar()); 1146b88c852SAart Bik return; 1156b88c852SAart Bik case AffineExprKind::Add: 1166b88c852SAart Bik case AffineExprKind::Mul: 1176b88c852SAart Bik case AffineExprKind::Mod: 1186b88c852SAart Bik case AffineExprKind::FloorDiv: 1196b88c852SAart Bik case AffineExprKind::CeilDiv: { 1206b88c852SAart Bik const auto [lhs, op, rhs] = expr.unpackBinop(); 1216b88c852SAart Bik (void)op; 1226b88c852SAart Bik add(lhs); 1236b88c852SAart Bik add(rhs); 1246b88c852SAart Bik return; 1256b88c852SAart Bik } 1266b88c852SAart Bik } 1276b88c852SAart Bik llvm_unreachable("unknown AffineExprKind"); 1286b88c852SAart Bik } 1296b88c852SAart Bik 1306b88c852SAart Bik //===----------------------------------------------------------------------===// 1316b88c852SAart Bik // `VarInfo` implementation. 
1326b88c852SAart Bik //===----------------------------------------------------------------------===// 1336b88c852SAart Bik 1346b88c852SAart Bik void VarInfo::setNum(Var::Num n) { 1356b88c852SAart Bik assert(!hasNum() && "Var::Num is already set"); 1366b88c852SAart Bik assert(Var::isWF_Num(n) && "Var::Num is too large"); 1376b88c852SAart Bik num = n; 1386b88c852SAart Bik } 1396b88c852SAart Bik 1406b88c852SAart Bik //===----------------------------------------------------------------------===// 1416b88c852SAart Bik // `VarEnv` implementation. 1426b88c852SAart Bik //===----------------------------------------------------------------------===// 1436b88c852SAart Bik 1446b88c852SAart Bik /// Helper function for `assertUsageConsistency` to better handle SMLoc 1456b88c852SAart Bik /// mismatches. 1466b88c852SAart Bik LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc 1476b88c852SAart Bik minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { 148a5757c5bSChristian Sigg const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1)); 1496b88c852SAart Bik assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); 150a5757c5bSChristian Sigg const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2)); 1516b88c852SAart Bik assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`"); 1526b88c852SAart Bik if (loc1.getFilename() != loc2.getFilename()) 1536b88c852SAart Bik return SMLoc(); 1546b88c852SAart Bik const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn()); 1556b88c852SAart Bik const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn()); 1566b88c852SAart Bik return pair1 <= pair2 ? 
sm1 : sm2; 1576b88c852SAart Bik } 1586b88c852SAart Bik 1598466eb7dSYinying Li bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) { 1606b88c852SAart Bik const auto &var = env.access(id); 1618466eb7dSYinying Li return (var.getName() == name && var.getID() == id); 1626b88c852SAart Bik } 1636b88c852SAart Bik 1648466eb7dSYinying Li bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, 1656b88c852SAart Bik VarKind vk) { 1666b88c852SAart Bik const auto &var = env.access(id); 1678466eb7dSYinying Li return var.getKind() == vk; 1686b88c852SAart Bik } 1696b88c852SAart Bik 1706b88c852SAart Bik std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const { 1716b88c852SAart Bik const auto iter = ids.find(name); 1726b88c852SAart Bik if (iter == ids.end()) 1736b88c852SAart Bik return std::nullopt; 1746b88c852SAart Bik const auto id = iter->second; 1758466eb7dSYinying Li if (!isInternalConsistent(*this, id, name)) 1768466eb7dSYinying Li return std::nullopt; 1776b88c852SAart Bik return id; 1786b88c852SAart Bik } 1796b88c852SAart Bik 1808466eb7dSYinying Li std::optional<std::pair<VarInfo::ID, bool>> 1818466eb7dSYinying Li VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) { 1826b88c852SAart Bik const auto &[iter, didInsert] = ids.try_emplace(name, nextID()); 1836b88c852SAart Bik const auto id = iter->second; 1846b88c852SAart Bik if (didInsert) { 1856b88c852SAart Bik vars.emplace_back(id, name, loc, vk); 1866b88c852SAart Bik } else { 1878466eb7dSYinying Li if (!isInternalConsistent(*this, id, name)) 1888466eb7dSYinying Li return std::nullopt; 1896b88c852SAart Bik if (verifyUsage) 1908466eb7dSYinying Li if (!isUsageConsistent(*this, id, loc, vk)) 1918466eb7dSYinying Li return std::nullopt; 1926b88c852SAart Bik } 1936b88c852SAart Bik return std::make_pair(id, didInsert); 1946b88c852SAart Bik } 1956b88c852SAart Bik 1966b88c852SAart Bik std::optional<std::pair<VarInfo::ID, bool>> 197ad7a6b67Swren romano 
VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, 1986b88c852SAart Bik VarKind vk) { 199ad7a6b67Swren romano switch (creationPolicy) { 200ad7a6b67Swren romano case Policy::MustNot: { 2016b88c852SAart Bik const auto oid = lookup(name); 2026b88c852SAart Bik if (!oid) 2036b88c852SAart Bik return std::nullopt; // Doesn't exist, but must not create. 2048466eb7dSYinying Li if (!isUsageConsistent(*this, *oid, loc, vk)) 2058466eb7dSYinying Li return std::nullopt; 2066b88c852SAart Bik return std::make_pair(*oid, false); 2076b88c852SAart Bik } 208ad7a6b67Swren romano case Policy::May: 2096b88c852SAart Bik return create(name, loc, vk, /*verifyUsage=*/true); 210ad7a6b67Swren romano case Policy::Must: { 2116b88c852SAart Bik const auto res = create(name, loc, vk, /*verifyUsage=*/false); 2128466eb7dSYinying Li const auto didCreate = res->second; 2136b88c852SAart Bik if (!didCreate) 2146b88c852SAart Bik return std::nullopt; // Already exists, but must create. 2156b88c852SAart Bik return res; 2166b88c852SAart Bik } 2176b88c852SAart Bik } 218ad7a6b67Swren romano llvm_unreachable("unknown Policy"); 2196b88c852SAart Bik } 2206b88c852SAart Bik 2216b88c852SAart Bik Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); } 2226b88c852SAart Bik Var VarEnv::bindVar(VarInfo::ID id) { 2236b88c852SAart Bik auto &info = access(id); 2246b88c852SAart Bik const auto var = bindUnusedVar(info.getKind()); 2256b88c852SAart Bik info.setNum(var.getNum()); 2266b88c852SAart Bik return var; 2276b88c852SAart Bik } 2286b88c852SAart Bik 2296b88c852SAart Bik InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const { 2306b88c852SAart Bik for (const auto &var : vars) 2316b88c852SAart Bik if (!var.hasNum()) 2326b88c852SAart Bik return parser.emitError(var.getLoc(), 2336b88c852SAart Bik "Unbound variable: " + var.getName()); 2346b88c852SAart Bik return {}; 2356b88c852SAart Bik } 2366b88c852SAart Bik 2376b88c852SAart Bik 
//===----------------------------------------------------------------------===//