1 //===- Var.cpp ------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Var.h" 10 #include "DimLvlMap.h" 11 12 using namespace mlir; 13 using namespace mlir::sparse_tensor; 14 using namespace mlir::sparse_tensor::ir_detail; 15 16 //===----------------------------------------------------------------------===// 17 // `VarKind` helpers. 18 //===----------------------------------------------------------------------===// 19 20 /// For use in foreach loops. 21 static constexpr const VarKind everyVarKind[] = { 22 VarKind::Dimension, VarKind::Symbol, VarKind::Level}; 23 24 //===----------------------------------------------------------------------===// 25 // `Var` implementation. 26 //===----------------------------------------------------------------------===// 27 28 std::string Var::str() const { 29 std::string str; 30 llvm::raw_string_ostream os(str); 31 print(os); 32 return str; 33 } 34 35 void Var::print(AsmPrinter &printer) const { print(printer.getStream()); } 36 37 void Var::print(llvm::raw_ostream &os) const { 38 os << toChar(getKind()) << getNum(); 39 } 40 41 void Var::dump() const { 42 print(llvm::errs()); 43 llvm::errs() << "\n"; 44 } 45 46 //===----------------------------------------------------------------------===// 47 // `Ranks` implementation. 48 //===----------------------------------------------------------------------===// 49 50 bool Ranks::operator==(Ranks const &other) const { 51 for (const auto vk : everyVarKind) 52 if (getRank(vk) != other.getRank(vk)) 53 return false; 54 return true; 55 } 56 57 bool Ranks::isValid(DimLvlExpr expr) const { 58 assert(expr); 59 // Compute the maximum identifiers for symbol-vars and dim/lvl-vars 60 // (each `DimLvlExpr` only allows one kind of non-symbol variable). 61 int64_t maxSym = -1, maxVar = -1; 62 mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}}, 63 maxVar, maxSym); 64 return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind()); 65 } 66 67 //===----------------------------------------------------------------------===// 68 // `VarSet` implementation. 69 //===----------------------------------------------------------------------===// 70 71 VarSet::VarSet(Ranks const &ranks) { 72 for (const auto vk : everyVarKind) 73 impl[vk] = llvm::SmallBitVector(ranks.getRank(vk)); 74 assert(getRanks() == ranks); 75 } 76 77 bool VarSet::contains(Var var) const { 78 // NOTE: We make sure to return false on OOB, for consistency with 79 // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`. 80 // However beware that, as always with silencing OOB, this can hide 81 // bugs in client code. 82 const llvm::SmallBitVector &bits = impl[var.getKind()]; 83 const auto num = var.getNum(); 84 return num < bits.size() && bits[num]; 85 } 86 87 void VarSet::add(Var var) { 88 // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB. 89 impl[var.getKind()][var.getNum()] = true; 90 } 91 92 void VarSet::add(VarSet const &other) { 93 // NOTE: `SmallBitVector::operator&=` will implicitly resize 94 // the bitvector (unlike `BitVector::operator&=`), so we add an 95 // assertion against OOB for consistency with the implementation 96 // of `VarSet::add(Var)`. 97 for (const auto vk : everyVarKind) { 98 assert(impl[vk].size() >= other.impl[vk].size()); 99 impl[vk] &= other.impl[vk]; 100 } 101 } 102 103 void VarSet::add(DimLvlExpr expr) { 104 if (!expr) 105 return; 106 switch (expr.getAffineKind()) { 107 case AffineExprKind::Constant: 108 return; 109 case AffineExprKind::SymbolId: 110 add(expr.castSymVar()); 111 return; 112 case AffineExprKind::DimId: 113 add(expr.castDimLvlVar()); 114 return; 115 case AffineExprKind::Add: 116 case AffineExprKind::Mul: 117 case AffineExprKind::Mod: 118 case AffineExprKind::FloorDiv: 119 case AffineExprKind::CeilDiv: { 120 const auto [lhs, op, rhs] = expr.unpackBinop(); 121 (void)op; 122 add(lhs); 123 add(rhs); 124 return; 125 } 126 } 127 llvm_unreachable("unknown AffineExprKind"); 128 } 129 130 //===----------------------------------------------------------------------===// 131 // `VarInfo` implementation. 132 //===----------------------------------------------------------------------===// 133 134 void VarInfo::setNum(Var::Num n) { 135 assert(!hasNum() && "Var::Num is already set"); 136 assert(Var::isWF_Num(n) && "Var::Num is too large"); 137 num = n; 138 } 139 140 //===----------------------------------------------------------------------===// 141 // `VarEnv` implementation. 142 //===----------------------------------------------------------------------===// 143 144 /// Helper function for `assertUsageConsistency` to better handle SMLoc 145 /// mismatches. 146 LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc 147 minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { 148 const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1)); 149 assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); 150 const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2)); 151 assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`"); 152 if (loc1.getFilename() != loc2.getFilename()) 153 return SMLoc(); 154 const auto pair1 = std::make_pair(loc1.getLine(), loc1.getColumn()); 155 const auto pair2 = std::make_pair(loc2.getLine(), loc2.getColumn()); 156 return pair1 <= pair2 ? sm1 : sm2; 157 } 158 159 bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) { 160 const auto &var = env.access(id); 161 return (var.getName() == name && var.getID() == id); 162 } 163 164 bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, 165 VarKind vk) { 166 const auto &var = env.access(id); 167 return var.getKind() == vk; 168 } 169 170 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const { 171 const auto iter = ids.find(name); 172 if (iter == ids.end()) 173 return std::nullopt; 174 const auto id = iter->second; 175 if (!isInternalConsistent(*this, id, name)) 176 return std::nullopt; 177 return id; 178 } 179 180 std::optional<std::pair<VarInfo::ID, bool>> 181 VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) { 182 const auto &[iter, didInsert] = ids.try_emplace(name, nextID()); 183 const auto id = iter->second; 184 if (didInsert) { 185 vars.emplace_back(id, name, loc, vk); 186 } else { 187 if (!isInternalConsistent(*this, id, name)) 188 return std::nullopt; 189 if (verifyUsage) 190 if (!isUsageConsistent(*this, id, loc, vk)) 191 return std::nullopt; 192 } 193 return std::make_pair(id, didInsert); 194 } 195 196 std::optional<std::pair<VarInfo::ID, bool>> 197 VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, 198 VarKind vk) { 199 switch (creationPolicy) { 200 case Policy::MustNot: { 201 const auto oid = lookup(name); 202 if (!oid) 203 return std::nullopt; // Doesn't exist, but must not create. 204 if (!isUsageConsistent(*this, *oid, loc, vk)) 205 return std::nullopt; 206 return std::make_pair(*oid, false); 207 } 208 case Policy::May: 209 return create(name, loc, vk, /*verifyUsage=*/true); 210 case Policy::Must: { 211 const auto res = create(name, loc, vk, /*verifyUsage=*/false); 212 const auto didCreate = res->second; 213 if (!didCreate) 214 return std::nullopt; // Already exists, but must create. 215 return res; 216 } 217 } 218 llvm_unreachable("unknown Policy"); 219 } 220 221 Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); } 222 Var VarEnv::bindVar(VarInfo::ID id) { 223 auto &info = access(id); 224 const auto var = bindUnusedVar(info.getKind()); 225 info.setNum(var.getNum()); 226 return var; 227 } 228 229 InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const { 230 for (const auto &var : vars) 231 if (!var.hasNum()) 232 return parser.emitError(var.getLoc(), 233 "Unbound variable: " + var.getName()); 234 return {}; 235 } 236 237 //===----------------------------------------------------------------------===// 238