1 //===- DimLvlMap.cpp ------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "DimLvlMap.h" 10 11 using namespace mlir; 12 using namespace mlir::sparse_tensor; 13 using namespace mlir::sparse_tensor::ir_detail; 14 15 //===----------------------------------------------------------------------===// 16 // `DimLvlExpr` implementation. 17 //===----------------------------------------------------------------------===// 18 19 Var DimLvlExpr::castAnyVar() const { 20 assert(expr && "uninitialized DimLvlExpr"); 21 const auto var = dyn_castAnyVar(); 22 assert(var && "expected DimLvlExpr to be a Var"); 23 return *var; 24 } 25 26 std::optional<Var> DimLvlExpr::dyn_castAnyVar() const { 27 if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>()) 28 return SymVar(s); 29 if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>()) 30 return Var(getAllowedVarKind(), x); 31 return std::nullopt; 32 } 33 34 SymVar DimLvlExpr::castSymVar() const { 35 return SymVar(expr.cast<AffineSymbolExpr>()); 36 } 37 38 std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const { 39 if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>()) 40 return SymVar(s); 41 return std::nullopt; 42 } 43 44 Var DimLvlExpr::castDimLvlVar() const { 45 return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>()); 46 } 47 48 std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const { 49 if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>()) 50 return Var(getAllowedVarKind(), x); 51 return std::nullopt; 52 } 53 54 int64_t DimLvlExpr::castConstantValue() const { 55 return expr.cast<AffineConstantExpr>().getValue(); 56 } 57 58 std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const { 59 const auto k = expr.dyn_cast_or_null<AffineConstantExpr>(); 60 return k ? std::make_optional(k.getValue()) : std::nullopt; 61 } 62 63 // This helper method is akin to `AffineExpr::operator==(int64_t)` 64 // except it uses a different implementation, namely the implementation 65 // used within `AsmPrinter::Impl::printAffineExprInternal`. 66 // 67 // wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses 68 // this implementation because it avoids constructing the intermediate 69 // `AffineConstantExpr(val)` and thus should in theory be a bit faster. 70 // However, if it is indeed faster, then the `AffineExpr::operator==` 71 // method should be updated to do this instead. And if it isn't any 72 // faster, then we should be using `AffineExpr::operator==` instead. 73 bool DimLvlExpr::hasConstantValue(int64_t val) const { 74 const auto k = expr.dyn_cast_or_null<AffineConstantExpr>(); 75 return k && k.getValue() == val; 76 } 77 78 DimLvlExpr DimLvlExpr::getLHS() const { 79 const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>(); 80 return DimLvlExpr(kind, binop ? binop.getLHS() : nullptr); 81 } 82 83 DimLvlExpr DimLvlExpr::getRHS() const { 84 const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>(); 85 return DimLvlExpr(kind, binop ? binop.getRHS() : nullptr); 86 } 87 88 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> 89 DimLvlExpr::unpackBinop() const { 90 const auto ak = getAffineKind(); 91 const auto binop = expr.dyn_cast<AffineBinaryOpExpr>(); 92 const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr); 93 const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr); 94 return {lhs, ak, rhs}; 95 } 96 97 void DimLvlExpr::dump() const { 98 print(llvm::errs()); 99 llvm::errs() << "\n"; 100 } 101 std::string DimLvlExpr::str() const { 102 std::string str; 103 llvm::raw_string_ostream os(str); 104 print(os); 105 return os.str(); 106 } 107 void DimLvlExpr::print(AsmPrinter &printer) const { 108 print(printer.getStream()); 109 } 110 void DimLvlExpr::print(llvm::raw_ostream &os) const { 111 if (!expr) 112 os << "<<NULL AFFINE EXPR>>"; 113 else 114 printWeak(os); 115 } 116 117 namespace { 118 struct MatchNeg final : public std::pair<DimLvlExpr, int64_t> { 119 using Base = std::pair<DimLvlExpr, int64_t>; 120 using Base::Base; 121 constexpr DimLvlExpr getLHS() const { return first; } 122 constexpr int64_t getRHS() const { return second; } 123 }; 124 } // namespace 125 126 static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) { 127 const auto [lhs, op, rhs] = expr.unpackBinop(); 128 if (op == AffineExprKind::Constant) { 129 const auto val = expr.castConstantValue(); 130 if (val < 0) 131 return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val}; 132 } 133 if (op == AffineExprKind::Mul) 134 if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0) 135 return MatchNeg{lhs, *rval}; 136 return std::nullopt; 137 } 138 139 // A heavily revised version of `AsmPrinter::Impl::printAffineExprInternal`. 140 void DimLvlExpr::printAffineExprInternal( 141 llvm::raw_ostream &os, BindingStrength enclosingTightness) const { 142 const char *binopSpelling = nullptr; 143 switch (getAffineKind()) { 144 case AffineExprKind::SymbolId: 145 os << castSymVar(); 146 return; 147 case AffineExprKind::DimId: 148 os << castDimLvlVar(); 149 return; 150 case AffineExprKind::Constant: 151 os << castConstantValue(); 152 return; 153 case AffineExprKind::Add: 154 binopSpelling = " + "; // N.B., this is unused 155 break; 156 case AffineExprKind::Mul: 157 binopSpelling = " * "; 158 break; 159 case AffineExprKind::FloorDiv: 160 binopSpelling = " floordiv "; 161 break; 162 case AffineExprKind::CeilDiv: 163 binopSpelling = " ceildiv "; 164 break; 165 case AffineExprKind::Mod: 166 binopSpelling = " mod "; 167 break; 168 } 169 170 if (enclosingTightness == BindingStrength::Strong) 171 os << '('; 172 173 const auto [lhs, op, rhs] = unpackBinop(); 174 if (op == AffineExprKind::Mul && rhs.hasConstantValue(-1)) { 175 // Pretty print `(lhs * -1)` as "-lhs". 176 os << '-'; 177 lhs.printStrong(os); 178 } else if (op != AffineExprKind::Add) { 179 // Default rule for tightly binding binary operators. 180 // (Including `Mul` that didn't match the previous rule.) 181 lhs.printStrong(os); 182 os << binopSpelling; 183 rhs.printStrong(os); 184 } else { 185 // Combination of all the special rules for addition/subtraction. 186 lhs.printWeak(os); 187 const auto rx = matchNeg(rhs); 188 os << (rx ? " - " : " + "); 189 const auto &rlhs = rx ? rx->getLHS() : rhs; 190 const auto rrhs = rx ? rx->getRHS() : -1; // value irrelevant when `!rx` 191 const bool nonunit = rrhs != -1; // value irrelevant when `!rx` 192 const bool isStrong = 193 rx && rlhs && (nonunit || rlhs.getAffineKind() == AffineExprKind::Add); 194 if (rlhs) 195 rlhs.printAffineExprInternal(os, BindingStrength{isStrong}); 196 if (rx && rlhs && nonunit) 197 os << " * "; 198 if (rx && (!rlhs || nonunit)) 199 os << -rrhs; 200 } 201 202 if (enclosingTightness == BindingStrength::Strong) 203 os << ')'; 204 } 205 206 //===----------------------------------------------------------------------===// 207 // `DimSpec` implementation. 208 //===----------------------------------------------------------------------===// 209 210 DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice) 211 : var(var), expr(expr), slice(slice) {} 212 213 bool DimSpec::isValid(Ranks const &ranks) const { 214 // Nothing in `slice` needs additional validation. 215 // We explicitly consider null-expr to be vacuously valid. 216 return ranks.isValid(var) && (!expr || ranks.isValid(expr)); 217 } 218 219 bool DimSpec::isFunctionOf(VarSet const &vars) const { 220 return vars.occursIn(expr); 221 } 222 223 void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); } 224 225 void DimSpec::dump() const { 226 print(llvm::errs(), /*wantElision=*/false); 227 llvm::errs() << "\n"; 228 } 229 std::string DimSpec::str(bool wantElision) const { 230 std::string str; 231 llvm::raw_string_ostream os(str); 232 print(os, wantElision); 233 return os.str(); 234 } 235 void DimSpec::print(AsmPrinter &printer, bool wantElision) const { 236 print(printer.getStream(), wantElision); 237 } 238 void DimSpec::print(llvm::raw_ostream &os, bool wantElision) const { 239 os << var; 240 if (expr && (!wantElision || !elideExpr)) 241 os << " = " << expr; 242 if (slice) { 243 os << " : "; 244 // Call `SparseTensorDimSliceAttr::print` directly, to avoid 245 // printing the mnemonic. 246 slice.print(os); 247 } 248 } 249 250 //===----------------------------------------------------------------------===// 251 // `LvlSpec` implementation. 252 //===----------------------------------------------------------------------===// 253 254 LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type) 255 : var(var), expr(expr), type(type) { 256 assert(expr); 257 assert(isValidDLT(type) && !isUndefDLT(type)); 258 } 259 260 bool LvlSpec::isValid(Ranks const &ranks) const { 261 // Nothing in `type` needs additional validation. 262 return ranks.isValid(var) && ranks.isValid(expr); 263 } 264 265 bool LvlSpec::isFunctionOf(VarSet const &vars) const { 266 return vars.occursIn(expr); 267 } 268 269 void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); } 270 271 void LvlSpec::dump() const { 272 print(llvm::errs(), /*wantElision=*/false); 273 llvm::errs() << "\n"; 274 } 275 std::string LvlSpec::str(bool wantElision) const { 276 std::string str; 277 llvm::raw_string_ostream os(str); 278 print(os, wantElision); 279 return os.str(); 280 } 281 void LvlSpec::print(AsmPrinter &printer, bool wantElision) const { 282 print(printer.getStream(), wantElision); 283 } 284 void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const { 285 if (!wantElision || !elideVar) 286 os << var << " = "; 287 os << expr; 288 os << ": " << toMLIRString(type); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // `DimLvlMap` implementation. 293 //===----------------------------------------------------------------------===// 294 295 DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs, 296 ArrayRef<LvlSpec> lvlSpecs) 297 : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs), 298 mustPrintLvlVars(false) { 299 // First, check integrity of the variable-binding structure. 300 // NOTE: This establishes the invariant that calls to `VarSet::add` 301 // below cannot cause OOB errors. 302 assert(isWF()); 303 304 // TODO: Second, we need to infer/validate the `lvlToDim` mapping. 305 // Along the way we should set every `DimSpec::elideExpr` according 306 // to whether the given expression is inferable or not. Notably, this 307 // needs to happen before the code for setting every `LvlSpec::elideVar`, 308 // since if the LvlVar is only used in elided DimExpr, then the 309 // LvlVar should also be elided. 310 // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs, 311 // to ensure that we maintain the invariant established by `isWF` above. 312 313 // Third, we set every `LvlSpec::elideVar` according to whether that 314 // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr). 315 // NOTE: The invariant established by `isWF` ensures that the following 316 // calls to `VarSet::add` cannot raise OOB errors. 317 VarSet usedVars(getRanks()); 318 for (const auto &dimSpec : dimSpecs) 319 if (!dimSpec.canElideExpr()) 320 usedVars.add(dimSpec.getExpr()); 321 for (auto &lvlSpec : this->lvlSpecs) { 322 // Is this LvlVar used in any overt expression? 323 const bool isUsed = usedVars.contains(lvlSpec.getBoundVar()); 324 // This LvlVar can be elided iff it isn't overtly used. 325 lvlSpec.setElideVar(!isUsed); 326 // If any LvlVar cannot be elided, then must forward-declare all LvlVars. 327 mustPrintLvlVars = mustPrintLvlVars || isUsed; 328 } 329 } 330 331 bool DimLvlMap::isWF() const { 332 const auto ranks = getRanks(); 333 unsigned dimNum = 0; 334 for (const auto &dimSpec : dimSpecs) 335 if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks)) 336 return false; 337 assert(dimNum == ranks.getDimRank()); 338 unsigned lvlNum = 0; 339 for (const auto &lvlSpec : lvlSpecs) 340 if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks)) 341 return false; 342 assert(lvlNum == ranks.getLvlRank()); 343 return true; 344 } 345 346 AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const { 347 SmallVector<AffineExpr> lvlAffines; 348 lvlAffines.reserve(getLvlRank()); 349 for (const auto &lvlSpec : lvlSpecs) 350 lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr()); 351 auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context); 352 if (map.isIdentity()) return AffineMap(); 353 return map; 354 } 355 356 AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const { 357 SmallVector<AffineExpr> dimAffines; 358 dimAffines.reserve(getDimRank()); 359 for (const auto &dimSpec : dimSpecs) { 360 auto expr = dimSpec.getExpr().getAffineExpr(); 361 if (expr) { 362 dimAffines.push_back(expr); 363 } 364 } 365 auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context); 366 if (dimAffines.empty() || map.isIdentity()) 367 return AffineMap(); 368 return map; 369 } 370 371 void DimLvlMap::dump() const { 372 print(llvm::errs(), /*wantElision=*/false); 373 llvm::errs() << "\n"; 374 } 375 std::string DimLvlMap::str(bool wantElision) const { 376 std::string str; 377 llvm::raw_string_ostream os(str); 378 print(os, wantElision); 379 return os.str(); 380 } 381 void DimLvlMap::print(AsmPrinter &printer, bool wantElision) const { 382 print(printer.getStream(), wantElision); 383 } 384 void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const { 385 // Symbolic identifiers. 386 // NOTE: Unlike `AffineMap` we place the SymVar bindings before the DimVar 387 // bindings, since the SymVars may occur within DimExprs and thus this 388 // ordering helps reduce potential user confusion about the scope of bidings 389 // (since it means SymVars and DimVars both bind-forward in the usual way, 390 // whereas only LvlVars have different binding rules). 391 if (symRank != 0) { 392 os << "[s0"; 393 for (unsigned i = 1; i < symRank; ++i) 394 os << ", s" << i; 395 os << ']'; 396 } 397 398 // LvlVar forward-declarations. 399 if (mustPrintLvlVars) { 400 os << '{'; 401 llvm::interleaveComma( 402 lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); }); 403 os << "} "; 404 } 405 406 // Dimension specifiers. 407 os << '('; 408 llvm::interleaveComma( 409 dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); }); 410 os << ") -> ("; 411 // Level specifiers. 412 wantElision = wantElision && !mustPrintLvlVars; 413 llvm::interleaveComma( 414 lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); }); 415 os << ')'; 416 } 417 418 //===----------------------------------------------------------------------===// 419