1 //===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for dealing with iteration lattices. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ 15 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 18 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 19 #include "mlir/IR/Value.h" 20 #include "llvm/ADT/BitVector.h" 21 22 #include <optional> 23 24 namespace mlir { 25 namespace sparse_tensor { 26 27 namespace detail { 28 /// A constant serving as the canonically invalid identifier, 29 /// regardless of the identifier type. 30 static constexpr unsigned kInvalidId = -1u; 31 } // namespace detail 32 33 /// Tensor identifiers, chosen to be the `BlockArgument::getArgNumber` 34 /// of the value passed to `Merger::buildTensorExp`. 35 using TensorId = unsigned; 36 37 /// Loop identifiers. 38 using LoopId = unsigned; 39 40 /// A compressed representation of `std::pair<TensorId, LoopId>`. 41 /// The compression scheme is such that this also serves as an index 42 /// into the bitvector stored in `LatPoint` (since that bitvector is 43 /// just the implementation for a set of `TensorLoopId` values). 44 using TensorLoopId = unsigned; 45 46 /// `TensorExp` identifiers. These are allocated by `Merger::addExp`, 47 /// and serve as unique identifiers for the corresponding `TensorExp` object. 48 using ExprId = unsigned; 49 50 /// `LatPoint` identifiers. These are allocated by `Merger::addLat`, 51 /// and serve as unique identifiers for the corresponding `LatPoint` object. 52 using LatPointId = unsigned; 53 54 /// `LatSet` identifiers. These are allocated by `Merger::addSet` (and 55 /// by other methods calling that one), and serve as unique identifiers 56 /// for the corresponding `SmallVector<LatPointId>` object. 57 using LatSetId = unsigned; 58 59 /// A pair of level and its corresponding LevelType of a tensor. 60 using LvlLTPair = std::pair<Level, LevelType>; 61 62 /// A pair of loop id and its coefficients. E.g., for affine expression in the 63 /// affine map `2 * d0`, loop id = 0, coefficient = 2. 64 using LoopCoeffPair = std::pair<LoopId, unsigned>; 65 66 /// Tensor expression. Represents an MLIR expression in tensor index notation. 67 struct TensorExp final { 68 enum class Kind; 69 70 /// Child subexpressions for non-leaf expressions. 71 struct Children final { 72 ExprId e0; 73 ExprId e1; 74 }; 75 76 /// The `x` parameter has different types depending on the value of the 77 /// `k` parameter. The correspondences are: 78 /// * `kTensor` -> `TensorId` 79 /// * `kInvariant` -> `kInvalidId` 80 /// * `kLoopVar` -> `LoopId` 81 /// * else -> `ExprId` 82 /// 83 /// The `y`, `v`, and `op` parameters either must or must not be 84 /// `kInvalidId`/`nullptr`, depending on the value of the `k` parameter; 85 /// however, they have uniform C++ types regardless of the value of `k`. 86 TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a); 87 88 /// Tensor expression kind. 89 Kind kind; 90 91 union { 92 /// `kTensor` expressions simply have a tensor identifier. 93 TensorId tensor; 94 95 /// `kLoopVar` expressions simply have a loop identifier. 96 LoopId loop; 97 98 /// All other expressions hold the `ExprId`s of their children. 99 Children children; 100 }; 101 102 /// Direct link to IR for an invariant or the destination value (to 103 /// infer destination type) of a cast operation During code generation, 104 /// this field may be used to cache "hoisted" loop invariant tensor loads. 105 Value val; 106 107 /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce, 108 /// and kSelect, this holds the original operation with all regions. For 109 /// kBinaryBranch, this holds the YieldOp for the left or right half 110 /// to be merged into a nested scf loop. 111 /// 112 /// Or the actual operation that we can not sparsify but having all dense 113 /// operands for kDenseOp. 114 Operation *op; 115 116 /// An optional attribute that is required to determine the semantics of the 117 /// operations. E.g., CmpPredicateAttr for CmpI/CmpF operations. 118 Attribute attr; 119 }; 120 121 /// Tensor expression kind. 122 /// 123 /// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`. 124 /// That is, its argument is a `LoopId` identifying the loop-variable 125 /// in question, and its value will be the current iteration's value. 126 /// The `kSynZero` leaf kind is for representing a synthetic zero value, 127 /// which can be introduced when sparsifying operations like `arith::cmp` 128 /// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent. 129 enum class TensorExp::Kind { 130 // Leaf. 131 kTensor = 0, 132 kSynZero, 133 kInvariant, 134 kLoopVar, 135 // Unary operations. 136 kAbsF, 137 kAbsC, 138 kAbsI, 139 kCeilF, 140 kFloorF, 141 kSqrtF, 142 kSqrtC, 143 kExpm1F, 144 kExpm1C, 145 kLog1pF, 146 kLog1pC, 147 kRelu, 148 kSinF, 149 kSinC, 150 kTanhF, 151 kTanhC, 152 kNegF, 153 kNegC, 154 kNegI, 155 kTruncF, 156 kExtF, 157 kCastFS, // signed 158 kCastFU, // unsigned 159 kCastSF, // signed 160 kCastUF, // unsigned 161 kCastS, // signed 162 kCastU, // unsigned 163 kCastIdx, 164 kTruncI, 165 kCIm, // complex.im 166 kCRe, // complex.re 167 kBitCast, 168 kBinaryBranch, // semiring unary branch created from a binary op 169 kUnary, // semiring unary op 170 kSelect, // custom selection criteria 171 // Binary operations. 172 kMulF, 173 kMulC, 174 kMulI, 175 kDivF, 176 kDivC, // complex 177 kDivS, // signed 178 kDivU, // unsigned 179 kAddF, 180 kAddC, 181 kAddI, 182 kSubF, 183 kSubC, 184 kSubI, 185 kAndI, 186 kOrI, 187 kXorI, 188 kCmpI, 189 kCmpF, 190 kShrS, // signed 191 kShrU, // unsigned 192 kShlI, 193 kBinary, // semiring binary op 194 kReduce, // semiring reduction op 195 kDenseOp, // special category of operations requiring all dense operands 196 }; 197 198 /// Lattice point. Each lattice point consists of a formal conjunction 199 /// of `TensorLoopId`s, together with the identifier of the corresponding 200 /// tensor expression. The formal conjunction is represented as a set of 201 /// `TensorLoopId`, where that set is implemented as a `BitVector`. 202 struct LatPoint final { 203 /// Construct a lattice point with the empty set of `TensorLoopId`s. LatPointfinal204 LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {} 205 206 /// Construct a lattice point from the given set of `TensorLoopId`s. LatPointfinal207 LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {} 208 209 /// Conjunction of all `TensorLoopId`s involved in the tensor expression. 210 BitVector bits; 211 212 /// Simplified conjunction of `TensorLoopId` as bitvector. This 213 /// represents a simplified condition under which this tensor expression 214 /// must execute. Pre-computed during codegen to avoid repeated eval. 215 BitVector simple; 216 217 /// Identifier of the tensor expression. 218 ExprId exp; 219 }; 220 221 /// A class to handle all iteration lattice operations. This class abstracts 222 /// away from some implementation details of storing iteration lattices and 223 /// tensor expressions. This allows for fine-tuning performance characteristics 224 /// independently from the basic algorithm if bottlenecks are identified. 225 class Merger { 226 public: 227 /// Constructs a merger for the given number of tensors and loops. The user 228 /// supplies the number of tensors involved in the kernel, with the last 229 /// tensor in this set denoting the output tensor. The merger adds an 230 /// additional synthetic tensor at the end of this set to represent all 231 /// invariant expressions in the kernel. 232 /// 233 /// The maxLvlRank specifies the max level rank of all inputs/output tensors. 234 /// It is used to pre-allocate sufficient memory for internal storage. 235 Merger(unsigned numInputOutputTensors, unsigned numLoops, 236 unsigned maxLvlRank); 237 238 // 239 // Constructing valid tensor and loop identifiers. 240 // 241 242 /// Safely converts the argument to a tensor identifier. makeTensorId(unsigned t)243 constexpr TensorId makeTensorId(unsigned t) const { 244 assert(isValidTensorId(t)); 245 return t; 246 } 247 248 /// Safely converts the argument to a loop identifier. makeLoopId(unsigned i)249 constexpr LoopId makeLoopId(unsigned i) const { 250 assert(isValidLoopId(i)); 251 return i; 252 } 253 254 /// Safely converts the arguments to a pair of (tensor,loop) identifiers. makeTensorLoopId(unsigned t,unsigned i)255 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { 256 assert(isValidTensorId(t) && isValidLoopId(i)); 257 return numTensors * i + t; 258 } 259 260 // 261 // Allocating new expressions, points, and sets. 262 // 263 264 /// Constructs a new tensor expression, and returns its identifier. 265 ExprId addTensorExp(TensorId t); 266 /// Constructs a new loop-variable expression, and returns its identifier. 267 ExprId addLoopVarExp(LoopId i); 268 /// Constructs a new invariant expression, and returns its identifier. 269 ExprId addInvariantExp(Value v); 270 /// Constructs a new synthetic zero expression. 271 ExprId addSynZeroExp(); 272 /// Constructs a new unary or binary expression, and returns its identifier. 273 ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = detail::kInvalidId, 274 Operation *op = nullptr, Attribute attr = nullptr); 275 /// Constructs a new sesquinary expression, and returns its identifier. 276 /// Currently no sesquinary `Kind` allows specifying the `op`, but we 277 /// allow it anyways because `mapSet` is designed to allow it. 278 ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr, 279 Attribute attr = nullptr); 280 281 /// Constructs a new iteration lattice point, and returns its identifier. 282 LatPointId addLat(TensorId t, LoopId i, ExprId e); 283 LatPointId addLat(const BitVector &bits, ExprId e); 284 285 /// Constructs a new (initially empty) set, and returns its identifier. 286 LatSetId addSet(); 287 288 /// Computes a single conjunction of two lattice points by taking the "union" 289 /// of `LoopId` (effectively constructing a larger "intersection" of those 290 /// loops) with a newly constructed tensor (sub)expression of given kind. 291 /// Returns the identifier of the new lattice point. 292 LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, 293 Operation *op = nullptr); 294 295 /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`. 296 /// Returns the identifier of the new set. 297 LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr); 298 299 /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`. 300 /// Returns the identifier of the new set. 301 LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr); 302 303 /// Disjunctive merge of two lattice sets and also set one of the operand to 304 /// zero: `(s0 /\_op s1 (e0 op e1), s0 (0 op e0), s1 (e1 op 0))`. 305 /// Returns the identifier of the new set. 306 LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1); 307 308 /// Disjunctive merge of two lattice sets with custom handling of the 309 /// overlap, left, and right regions. Any region may be left missing 310 /// in the output. Returns the identifier of the new set. 311 LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, 312 bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, 313 bool includeRight, TensorExp::Kind rtrans, 314 Operation *opright); 315 316 /// Maps the unary operator over the lattice set of the operand, i.e. each 317 /// lattice point on an expression E is simply copied over, but with OP E 318 /// as new expression. Returns the identifier of the new set. 319 LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(), 320 Operation *op = nullptr, Attribute attr = nullptr); 321 322 /// Maps the binary operator to the same operation but with one of its operand 323 /// set to zero, i.e. each lattice point on an expression E is simply copied 324 /// over, but with `OP 0 E` (if lhsZero == true) or `OP E 0` (if lhsZero == 325 /// false) as new expression. Returns the identifier of the new set. 326 LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero); 327 328 /// Optimizes the iteration lattice points in the given set. This 329 /// method should be called right before code generation to avoid 330 /// generating redundant loops and conditions. 331 LatSetId optimizeSet(LatSetId s); 332 333 /// Simplifies the conditions in a conjunction of a given lattice point 334 /// within the given set using just two basic rules: 335 /// (1) multiple dense conditions are reduced to single dense, and 336 /// (2) a *singleton* sparse/dense is reduced to sparse/random access. 337 BitVector simplifyCond(LatSetId s, LatPointId p); 338 339 /// Returns true if p0 > p1. 340 bool latGT(LatPointId p0, LatPointId p1) const; 341 342 /// Returns true if p0 and p1 only differ in dense. 343 bool onlyDenseDiff(LatPointId p0, LatPointId p1) const; 344 345 /// Gets the tensor-identifier of the `TensorLoopId`. tensor(TensorLoopId b)346 constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; } 347 /// Gets the loop-identifier of the `TensorLoopId`. loop(TensorLoopId b)348 constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; } 349 350 /// Gets the total number of tensors (including the output-tensor and 351 /// synthetic-tensor). getNumTensors()352 constexpr unsigned getNumTensors() const { return numTensors; } 353 354 /// Gets the total number of loops (native loops + filter loops). getNumLoops()355 constexpr unsigned getNumLoops() const { return numLoops; } 356 357 /// Returns true if `b` is the `i`th loop of the output tensor. isOutTensor(TensorLoopId b,LoopId i)358 constexpr bool isOutTensor(TensorLoopId b, LoopId i) const { 359 return b == makeTensorLoopId(outTensor, i); 360 } 361 362 /// Gets the output tensor's identifier. getOutTensorID()363 constexpr TensorId getOutTensorID() const { return outTensor; } 364 365 /// Gets the synthetic tensor's identifier (used for all invariant 366 /// tensor expressions). getSynTensorID()367 constexpr TensorId getSynTensorID() const { return syntheticTensor; } 368 369 /// Returns true if the expression is `(kTensor t)`. expIsTensor(ExprId e,TensorId t)370 bool expIsTensor(ExprId e, TensorId t) const { 371 const auto &expr = exp(e); 372 return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t; 373 } 374 375 /// Returns true if the expression contains the tensor as an operand. 376 bool expContainsTensor(ExprId e, TensorId t) const; 377 378 /// Returns true if the expression contains a negation on output tensor. 379 /// I.e., `- outTensor` or `exp - outputTensor` 380 /// NOTE: this is an trivial tests in that it does not handle recursive 381 /// negation, i.e., it returns true when the expression is `-(-tensor)`. 382 bool hasNegateOnOut(ExprId e) const; 383 384 /// Returns true if given tensor iterates *only* in the given tensor 385 /// expression. For the output tensor, this defines a "simply dynamic" 386 /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for 387 /// sparse vector a. 388 bool isSingleCondition(TensorId t, ExprId e) const; 389 390 /// Returns true if any `TensorLoopId` in the bitvector corresponds 391 /// to sparse level-type. 392 bool hasAnySparse(const BitVector &bits) const; 393 394 /// Returns true if bits contains a dependent index reduction condition on 395 /// sparse levels. 396 bool hasSparseIdxReduction(const BitVector &bits) const; 397 398 /// Gets the level-type of the `t`th tensor on `i`th loop. getLvlType(TensorId t,LoopId i)399 LevelType getLvlType(TensorId t, LoopId i) const { 400 assert(isValidTensorId(t) && isValidLoopId(i)); 401 return lvlTypes[t][i]; 402 } 403 404 /// Gets the level-type of the TensorLoopId. getLvlType(TensorLoopId b)405 LevelType getLvlType(TensorLoopId b) const { 406 return getLvlType(tensor(b), loop(b)); 407 } 408 409 /// Gets the loop identifier for the `lvl`th level of the `t`th tensor. getLoopId(TensorId t,Level lvl)410 std::optional<LoopId> getLoopId(TensorId t, Level lvl) const { 411 assert(isValidLevel(t, lvl)); 412 return lvlToLoop[t][lvl]; 413 } 414 415 /// Gets the level number of the the `t`th tensor on `i`th loop. getLvl(TensorId t,LoopId i)416 std::optional<Level> getLvl(TensorId t, LoopId i) const { 417 assert(isValidTensorId(t) && isValidLoopId(i)); 418 return loopToLvl[t][i]; 419 } getLvl(TensorLoopId b)420 std::optional<Level> getLvl(TensorLoopId b) const { 421 return getLvl(tensor(b), loop(b)); 422 } 423 424 /// Sets the level number and level-type of the `t`th tensor on 425 /// `i`th loop. setLevelAndType(TensorId t,LoopId i,Level lvl,LevelType lt)426 void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt) { 427 assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt)); 428 lvlTypes[t][i] = lt; 429 loopToLvl[t][i] = lvl; 430 lvlToLoop[t][lvl] = i; 431 // TODO: favor a constant loop bound when there are multiple choices. 432 loopBounds[i] = std::make_pair(t, lvl); 433 } 434 435 using ForeachTensorLoopIdCallback = function_ref<void( 436 TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>; 437 438 /// Iterates over a set of `TensorLoopId`s, invoking the callback 439 /// for each `TensorLoopId` and passing it the corresponding tensor 440 /// identifier, level, and level-type, following with a boolean value 441 /// indicating whether it is a dependent index reduction loop condition. foreachTensorLoopId(LatPointId p,ForeachTensorLoopIdCallback callback)442 void foreachTensorLoopId(LatPointId p, 443 ForeachTensorLoopIdCallback callback) const { 444 // TODO: the default ought to be simple=true; but we'll need to make 445 // sure to update all the tests to make sure they do the right thing. 446 foreachTensorLoopId(p, /*simple=*/false, callback); 447 } foreachTensorLoopId(LatPointId p,bool simple,ForeachTensorLoopIdCallback callback)448 void foreachTensorLoopId(LatPointId p, bool simple, 449 ForeachTensorLoopIdCallback callback) const { 450 const auto &point = lat(p); 451 const auto &bits = simple ? point.simple : point.bits; 452 for (const TensorLoopId b : bits.set_bits()) { 453 const TensorId t = tensor(b); 454 const auto optLvl = getLvl(b); 455 const auto lvlTp = getLvlType(b); 456 if (isLvlWithNonTrivialIdxExp(b)) { 457 // This must be an undefined level. 458 assert(!optLvl.has_value()); 459 // Slice the tid along the dependent level to iterate current loop. 460 callback(b, t, getLoopDependentLevel(b), lvlTp, 461 /*isIdxReduc=*/true); 462 } else { 463 callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false); 464 } 465 } 466 } 467 468 /// Sets whether the output tensor is sparse or not. setHasSparseOut(bool s)469 void setHasSparseOut(bool s) { hasSparseOut = s; } 470 471 /// Establishes the two-way map that i <-> <t, lvl, lt>. setLoopDependentTensorLevel(LoopId i,TensorId t,Level lvl,LevelType lt,unsigned coefficient)472 void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, 473 LevelType lt, unsigned coefficient) { 474 assert(isValidLoopId(i) && isValidLevel(t, lvl)); 475 assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def 476 loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt); 477 levelToDependentLoop[t][lvl].emplace_back(i, coefficient); 478 } 479 480 /// Whether the loop has dependent slice. hasDependentLvl(LoopId i,TensorId t)481 bool hasDependentLvl(LoopId i, TensorId t) { 482 assert(isValidTensorId(t) && isValidLoopId(i)); 483 return loopToUnresolvedLvls[i][t].has_value(); 484 } 485 486 /// Returns the list of loop indices which appear in the non-trivial index 487 /// expression on t_l, e.g., A[i+j] => {i, j} getDependentLoops(TensorId t,Level lvl)488 std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) { 489 assert(isValidLevel(t, lvl)); 490 return levelToDependentLoop[t][lvl]; 491 } 492 493 /// Returns the defining [tid, lvl] for the loop. getLoopDefiningLvl(LoopId i)494 std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const { 495 assert(isValidLoopId(i)); 496 return loopBounds[i]; 497 } 498 499 /// Checks whether the TensorLoopId represents a tensor level contains 500 /// non-trivial index expression. isLvlWithNonTrivialIdxExp(TensorLoopId b)501 bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const { 502 const TensorId t = tensor(b); 503 const LoopId i = loop(b); 504 assert(isValidTensorId(t) && isValidLoopId(i)); 505 return loopToUnresolvedLvls[i][t].has_value(); 506 } 507 508 /// Checks whether the TensorLoopId represents a sparse tensor level contains 509 /// non-trivial index expression. isSparseLvlWithNonTrivialIdxExp(TensorLoopId b)510 bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const { 511 if (isLvlWithNonTrivialIdxExp(b)) { 512 auto lt = getLoopDependentLevelType(b); 513 return lt.hasSparseSemantic(); 514 } 515 return false; 516 } 517 getLoopDependentLevel(TensorLoopId b)518 Level getLoopDependentLevel(TensorLoopId b) const { 519 assert(isLvlWithNonTrivialIdxExp(b)); 520 return loopToUnresolvedLvls[loop(b)][tensor(b)]->first; 521 } 522 getLoopDependentLevelType(TensorLoopId b)523 LevelType getLoopDependentLevelType(TensorLoopId b) const { 524 assert(isLvlWithNonTrivialIdxExp(b)); 525 return loopToUnresolvedLvls[loop(b)][tensor(b)]->second; 526 } 527 528 /// Convenience getters to immediately access the stored nodes. 529 /// These methods return `const&` because the underlying objects must 530 /// not be mutated by client code. The only exception is for mutating 531 /// the value associated with an expression, for which there are 532 /// dedicated methods below. 533 /// 534 /// NOTE: It is inadvisable to keep the reference alive for a long 535 /// time (e.g., as in `TensorExpr &te = merger.exp(e)`), since insertions 536 /// into the merger can cause data movement which will invalidate the 537 /// underlying memory address. This isn't just a problem with the `&` 538 /// references, but also applies to the `ArrayRef`. In particular, 539 /// using `for (LatPointId p : merger.set(s))` will run into the same 540 /// dangling-reference problems if the loop body inserts new sets. exp(ExprId e)541 const TensorExp &exp(ExprId e) const { 542 assert(isValidExprId(e)); 543 return tensorExps[e]; 544 } lat(LatPointId p)545 const LatPoint &lat(LatPointId p) const { 546 assert(isValidLatPointId(p)); 547 return latPoints[p]; 548 } set(LatSetId s)549 ArrayRef<LatPointId> set(LatSetId s) const { 550 assert(isValidLatSetId(s)); 551 return latSets[s]; 552 } 553 554 /// Checks whether the given expression has an associated value. hasExprValue(ExprId e)555 bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); } 556 557 /// Sets the expression to have the associated value. Asserts that the new 558 /// value is defined, and that the expression does not already have a value. setExprValue(ExprId e,Value v)559 void setExprValue(ExprId e, Value v) { 560 assert(!exp(e).val && "Expression already has an associated value"); 561 assert(v && "Trying to assign an undefined value"); 562 tensorExps[e].val = v; 563 } 564 565 /// Clears the value associated with the expression. Asserts that the 566 /// expression does indeed have an associated value before clearing it. clearExprValue(ExprId e)567 void clearExprValue(ExprId e) { 568 assert(exp(e).val && "Expression does not have an associated value"); 569 tensorExps[e].val = Value(); 570 } 571 572 #ifndef NDEBUG 573 /// Print methods (for debugging). 574 void dumpExp(ExprId e) const; 575 void dumpLat(LatPointId p) const; 576 void dumpSet(LatSetId s) const; 577 void dumpBits(const BitVector &bits) const; 578 #endif 579 580 /// Builds the iteration lattices in a bottom-up traversal given the 581 /// remaining tensor (sub)expression and the next loop in the iteration 582 /// graph. Returns the identifier of the root set. 583 LatSetId buildLattices(ExprId e, LoopId i); 584 585 /// Builds a tensor expression from the given Linalg operation. 586 /// On success, returns the identifier of the root expression. 587 std::optional<ExprId> buildTensorExpFromLinalg(linalg::GenericOp op); 588 589 /// Rebuilds SSA format from a tensor expression. 590 Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, 591 Value v1) const; 592 593 private: 594 /// Private helpers. isValidTensorId(TensorId t)595 constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; } isValidLoopId(LoopId i)596 constexpr bool isValidLoopId(LoopId i) const { 597 return i != detail::kInvalidId && i < numLoops; 598 } isValidLevel(TensorId t,Level lvl)599 bool isValidLevel(TensorId t, Level lvl) const { 600 assert(levelToDependentLoop[t].size() == lvlToLoop[t].size()); 601 return isValidTensorId(t) && lvl < lvlToLoop[t].size(); 602 } isValidExprId(ExprId e)603 bool isValidExprId(ExprId e) const { 604 return e != detail::kInvalidId && e < tensorExps.size(); 605 } isValidLatPointId(LatPointId p)606 bool isValidLatPointId(LatPointId p) const { 607 return p != detail::kInvalidId && p < latPoints.size(); 608 } isValidLatSetId(LatSetId s)609 bool isValidLatSetId(LatSetId s) const { 610 return s != detail::kInvalidId && s < latSets.size(); 611 } 612 bool maybeZero(ExprId e) const; isInvariant(ExprId e)613 bool isInvariant(ExprId e) const { 614 return exp(e).kind == TensorExp::Kind::kInvariant; 615 } 616 Type inferType(ExprId e, Value src) const; 617 618 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. 619 /// The boolean value returned indicates whether the result of the current 620 /// operation being built depends on any value that is loaded from a sparse 621 /// tensor. 622 std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op, 623 Value v); 624 625 /// Merger data structures. 626 const TensorId outTensor; 627 const TensorId syntheticTensor; 628 const unsigned numTensors; 629 const unsigned numLoops; 630 bool hasSparseOut; 631 632 // Below we use `std::vector` for things which have a priori fixed 633 // sizes, whereas we use `llvm::SmallVector` for things with variable 634 // size. Do beware that these two classes differ in the semantics of 635 // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector` 636 // does not. 637 638 /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type. 639 std::vector<std::vector<LevelType>> lvlTypes; 640 641 /// Map that converts pair<TensorId, LoopId> to the corresponding lvl. 642 std::vector<std::vector<std::optional<Level>>> loopToLvl; 643 644 /// Map that converts pair<TensorId, Level> to the corresponding LoopId. 645 std::vector<std::vector<std::optional<LoopId>>> lvlToLoop; 646 647 /// Map from a loop to its dependencies if any. 648 /// The dependencies of a loop is a set of (tensor, level) pairs. 649 /// It is currently only set for non-trivial index expressions. 650 /// E.g., A[i+j] => i and j will have dependencies {A0, lt(A0)} to indicate 651 /// that i and j are used in the non-trivial index expression on A0. 652 std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls; 653 654 /// The inverse map of ldxToDependencies from tensor level -> dependent loop 655 /// E.g., A[2i+j], we have A0 => {(2, i), (1, j)}, to indicate that A0 uses 656 /// both {i, j} to compute its indices and the coefficients on the loop id are 657 /// 2 and 1 respectively. 658 std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop; 659 660 /// Map from a loop to the [tid, lvl] pair that defines the loop boundary. 661 std::vector<std::pair<TensorId, Level>> loopBounds; 662 663 llvm::SmallVector<TensorExp> tensorExps; 664 llvm::SmallVector<LatPoint> latPoints; 665 llvm::SmallVector<SmallVector<LatPointId>> latSets; 666 }; 667 668 } // namespace sparse_tensor 669 } // namespace mlir 670 671 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ 672