// xref: /llvm-project/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (revision 70e227a404e51f9248c7ad5d79953805b2afacb4)
//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for dealing with iteration lattices.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/BitVector.h"

#include <optional>

namespace mlir {
namespace sparse_tensor {

namespace detail {
/// A constant serving as the canonically invalid identifier,
/// regardless of the identifier type.
///
/// Declared `inline` (C++17) so that this header contributes a single
/// constant with external linkage, instead of one internal-linkage copy
/// per translation unit as the previous `static constexpr` did.
inline constexpr unsigned kInvalidId = -1u;
} // namespace detail
32 
/// Tensor identifiers, chosen to be the `BlockArgument::getArgNumber`
/// of the value passed to `Merger::buildTensorExp`.
using TensorId = unsigned;

/// Loop identifiers.
using LoopId = unsigned;

/// A compressed representation of `std::pair<TensorId, LoopId>`.
/// The compression scheme is such that this also serves as an index
/// into the bitvector stored in `LatPoint` (since that bitvector is
/// just the implementation for a set of `TensorLoopId` values).
using TensorLoopId = unsigned;

/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
/// and serve as unique identifiers for the corresponding `TensorExp` object.
using ExprId = unsigned;

/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
/// and serve as unique identifiers for the corresponding `LatPoint` object.
using LatPointId = unsigned;

/// `LatSet` identifiers.  These are allocated by `Merger::addSet` (and
/// by other methods calling that one), and serve as unique identifiers
/// for the corresponding `SmallVector<LatPointId>` object.
using LatSetId = unsigned;

/// A pair of level and its corresponding LevelType of a tensor.
using LvlLTPair = std::pair<Level, LevelType>;

/// A pair of loop id and its coefficient. E.g., for the affine expression
/// in the affine map `2 * d0`, loop id = 0, coefficient = 2.
using LoopCoeffPair = std::pair<LoopId, unsigned>;
65 
/// Tensor expression. Represents an MLIR expression in tensor index notation.
struct TensorExp final {
  enum class Kind;

  /// Child subexpressions for non-leaf expressions.
  struct Children final {
    ExprId e0;
    ExprId e1;
  };

  /// The `x` parameter has different types depending on the value of the
  /// `k` parameter.  The correspondences are:
  /// * `kTensor`    -> `TensorId`
  /// * `kInvariant` -> `kInvalidId`
  /// * `kLoopVar`   -> `LoopId`
  /// * else         -> `ExprId`
  ///
  /// The `y`, `v`, and `op` parameters either must or must not be
  /// `kInvalidId`/`nullptr`, depending on the value of the `k` parameter;
  /// however, they have uniform C++ types regardless of the value of `k`.
  TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a);

  /// Tensor expression kind.
  Kind kind;

  // Anonymous union: exactly one member is active, as selected by `kind`.
  union {
    /// `kTensor` expressions simply have a tensor identifier.
    TensorId tensor;

    /// `kLoopVar` expressions simply have a loop identifier.
    LoopId loop;

    /// All other expressions hold the `ExprId`s of their children.
    Children children;
  };

  /// Direct link to IR for an invariant or the destination value (to
  /// infer destination type) of a cast operation. During code generation,
  /// this field may be used to cache "hoisted" loop invariant tensor loads.
  Value val;

  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
  /// and kSelect, this holds the original operation with all regions. For
  /// kBinaryBranch, this holds the YieldOp for the left or right half
  /// to be merged into a nested scf loop.
  ///
  /// Or the actual operation that we can not sparsify but having all dense
  /// operands for kDenseOp.
  Operation *op;

  /// An optional attribute that is required to determine the semantics of the
  /// operations. E.g., CmpPredicateAttr for CmpI/CmpF operations.
  Attribute attr;
};
120 
/// Tensor expression kind.
///
/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
/// That is, its argument is a `LoopId` identifying the loop-variable
/// in question, and its value will be the current iteration's value.
/// The `kSynZero` leaf kind is for representing a synthetic zero value,
/// which can be introduced when sparsifying operations like `arith::cmp`
/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
enum class TensorExp::Kind {
  // Leaf.
  kTensor = 0,
  kSynZero,
  kInvariant,
  kLoopVar,
  // Unary operations.
  kAbsF,
  kAbsC,
  kAbsI,
  kCeilF,
  kFloorF,
  kSqrtF,
  kSqrtC,
  kExpm1F,
  kExpm1C,
  kLog1pF,
  kLog1pC,
  kRelu,
  kSinF,
  kSinC,
  kTanhF,
  kTanhC,
  kNegF,
  kNegC,
  kNegI,
  kTruncF,
  kExtF,
  kCastFS, // signed
  kCastFU, // unsigned
  kCastSF, // signed
  kCastUF, // unsigned
  kCastS,  // signed
  kCastU,  // unsigned
  kCastIdx,
  kTruncI,
  kCIm, // complex.im
  kCRe, // complex.re
  kBitCast,
  kBinaryBranch, // semiring unary branch created from a binary op
  kUnary,        // semiring unary op
  kSelect,       // custom selection criteria
  // Binary operations.
  kMulF,
  kMulC,
  kMulI,
  kDivF,
  kDivC, // complex
  kDivS, // signed
  kDivU, // unsigned
  kAddF,
  kAddC,
  kAddI,
  kSubF,
  kSubC,
  kSubI,
  kAndI,
  kOrI,
  kXorI,
  kCmpI,
  kCmpF,
  kShrS, // signed
  kShrU, // unsigned
  kShlI,
  kBinary,  // semiring binary op
  kReduce,  // semiring reduction op
  kDenseOp, // special category of operations requiring all dense operands
};
197 
/// Lattice point.  Each lattice point consists of a formal conjunction
/// of `TensorLoopId`s, together with the identifier of the corresponding
/// tensor expression.  The formal conjunction is represented as a set of
/// `TensorLoopId`, where that set is implemented as a `BitVector`.
struct LatPoint final {
  /// Construct a lattice point with the empty set of `TensorLoopId`s
  /// (a bitvector of `size` bits, all clear).
  LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {}

  /// Construct a lattice point from the given set of `TensorLoopId`s.
  LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}

  /// Conjunction of all `TensorLoopId`s involved in the tensor expression.
  BitVector bits;

  /// Simplified conjunction of `TensorLoopId` as bitvector.  This
  /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during codegen to avoid repeated eval.
  BitVector simple;

  /// Identifier of the tensor expression.
  ExprId exp;
};
220 
/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  /// Constructs a merger for the given number of tensors and loops. The user
  /// supplies the number of tensors involved in the kernel, with the last
  /// tensor in this set denoting the output tensor. The merger adds an
  /// additional synthetic tensor at the end of this set to represent all
  /// invariant expressions in the kernel.
  ///
  /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
  /// It is used to pre-allocate sufficient memory for internal storage.
  Merger(unsigned numInputOutputTensors, unsigned numLoops,
         unsigned maxLvlRank);

  //
  // Constructing valid tensor and loop identifiers.
  //

  /// Safely converts the argument to a tensor identifier.
  constexpr TensorId makeTensorId(unsigned t) const {
    assert(isValidTensorId(t));
    return t;
  }

  /// Safely converts the argument to a loop identifier.
  constexpr LoopId makeLoopId(unsigned i) const {
    assert(isValidLoopId(i));
    return i;
  }

  /// Safely converts the arguments to a pair of (tensor,loop) identifiers.
  constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
    assert(isValidTensorId(t) && isValidLoopId(i));
    // Encoding is decoded by `tensor()` (modulo) and `loop()` (division).
    return numTensors * i + t;
  }

  //
  // Allocating new expressions, points, and sets.
  //

  /// Constructs a new tensor expression, and returns its identifier.
  ExprId addTensorExp(TensorId t);
  /// Constructs a new loop-variable expression, and returns its identifier.
  ExprId addLoopVarExp(LoopId i);
  /// Constructs a new invariant expression, and returns its identifier.
  ExprId addInvariantExp(Value v);
  /// Constructs a new synthetic zero expression.
  ExprId addSynZeroExp();
  /// Constructs a new unary or binary expression, and returns its identifier.
  ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = detail::kInvalidId,
                Operation *op = nullptr, Attribute attr = nullptr);
  /// Constructs a new sesquinary expression, and returns its identifier.
  /// Currently no sesquinary `Kind` allows specifying the `op`, but we
  /// allow it anyways because `mapSet` is designed to allow it.
  ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr,
                Attribute attr = nullptr);

  /// Constructs a new iteration lattice point, and returns its identifier.
  LatPointId addLat(TensorId t, LoopId i, ExprId e);
  LatPointId addLat(const BitVector &bits, ExprId e);

  /// Constructs a new (initially empty) set, and returns its identifier.
  LatSetId addSet();

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of `LoopId` (effectively constructing a larger "intersection" of those
  /// loops) with a newly constructed tensor (sub)expression of given kind.
  /// Returns the identifier of the new lattice point.
  LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1,
                     Operation *op = nullptr);

  /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
  /// Returns the identifier of the new set.
  LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);

  /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
  /// Returns the identifier of the new set.
  LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);

  /// Disjunctive merge of two lattice sets and also set one of the operand to
  /// zero: `(s0 /\_op s1 (e0 op e1), s0 (0 op e0), s1 (e1 op 0))`.
  /// Returns the identifier of the new set.
  LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1);

  /// Disjunctive merge of two lattice sets with custom handling of the
  /// overlap, left, and right regions.  Any region may be left missing
  /// in the output.  Returns the identifier of the new set.
  LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                    bool includeLeft, TensorExp::Kind ltrans, Operation *opleft,
                    bool includeRight, TensorExp::Kind rtrans,
                    Operation *opright);

  /// Maps the unary operator over the lattice set of the operand, i.e. each
  /// lattice point on an expression E is simply copied over, but with OP E
  /// as new expression. Returns the identifier of the new set.
  LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
                  Operation *op = nullptr, Attribute attr = nullptr);

  /// Maps the binary operator to the same operation but with one of its operand
  /// set to zero, i.e. each lattice point on an expression E is simply copied
  /// over, but with `OP 0 E` (if lhsZero == true) or `OP E 0` (if lhsZero ==
  /// false) as new expression. Returns the identifier of the new set.
  LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero);

  /// Optimizes the iteration lattice points in the given set. This
  /// method should be called right before code generation to avoid
  /// generating redundant loops and conditions.
  LatSetId optimizeSet(LatSetId s);

  /// Simplifies the conditions in a conjunction of a given lattice point
  /// within the given set using just two basic rules:
  /// (1) multiple dense conditions are reduced to single dense, and
  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
  BitVector simplifyCond(LatSetId s, LatPointId p);

  /// Returns true if p0 > p1.
  bool latGT(LatPointId p0, LatPointId p1) const;

  /// Returns true if p0 and p1 only differ in dense.
  bool onlyDenseDiff(LatPointId p0, LatPointId p1) const;

  /// Gets the tensor-identifier of the `TensorLoopId`.
  constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; }
  /// Gets the loop-identifier of the `TensorLoopId`.
  constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; }

  /// Gets the total number of tensors (including the output-tensor and
  /// synthetic-tensor).
  constexpr unsigned getNumTensors() const { return numTensors; }

  /// Gets the total number of loops (native loops + filter loops).
  constexpr unsigned getNumLoops() const { return numLoops; }

  /// Returns true if `b` is the `i`th loop of the output tensor.
  constexpr bool isOutTensor(TensorLoopId b, LoopId i) const {
    return b == makeTensorLoopId(outTensor, i);
  }

  /// Gets the output tensor's identifier.
  constexpr TensorId getOutTensorID() const { return outTensor; }

  /// Gets the synthetic tensor's identifier (used for all invariant
  /// tensor expressions).
  constexpr TensorId getSynTensorID() const { return syntheticTensor; }

  /// Returns true if the expression is `(kTensor t)`.
  bool expIsTensor(ExprId e, TensorId t) const {
    const auto &expr = exp(e);
    return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t;
  }

  /// Returns true if the expression contains the tensor as an operand.
  bool expContainsTensor(ExprId e, TensorId t) const;

  /// Returns true if the expression contains a negation on output tensor.
  /// I.e., `- outTensor` or `exp - outTensor`
  /// NOTE: this is a trivial test in that it does not handle recursive
  /// negation, i.e., it returns true when the expression is `-(-tensor)`.
  bool hasNegateOnOut(ExprId e) const;

  /// Returns true if given tensor iterates *only* in the given tensor
  /// expression. For the output tensor, this defines a "simply dynamic"
  /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
  /// sparse vector a.
  bool isSingleCondition(TensorId t, ExprId e) const;

  /// Returns true if any `TensorLoopId` in the bitvector corresponds
  /// to sparse level-type.
  bool hasAnySparse(const BitVector &bits) const;

  /// Returns true if bits contains a dependent index reduction condition on
  /// sparse levels.
  bool hasSparseIdxReduction(const BitVector &bits) const;

  /// Gets the level-type of the `t`th tensor on `i`th loop.
  LevelType getLvlType(TensorId t, LoopId i) const {
    assert(isValidTensorId(t) && isValidLoopId(i));
    return lvlTypes[t][i];
  }

  /// Gets the level-type of the TensorLoopId.
  LevelType getLvlType(TensorLoopId b) const {
    return getLvlType(tensor(b), loop(b));
  }

  /// Gets the loop identifier for the `lvl`th level of the `t`th tensor.
  std::optional<LoopId> getLoopId(TensorId t, Level lvl) const {
    assert(isValidLevel(t, lvl));
    return lvlToLoop[t][lvl];
  }

  /// Gets the level number of the `t`th tensor on `i`th loop.
  std::optional<Level> getLvl(TensorId t, LoopId i) const {
    assert(isValidTensorId(t) && isValidLoopId(i));
    return loopToLvl[t][i];
  }
  std::optional<Level> getLvl(TensorLoopId b) const {
    return getLvl(tensor(b), loop(b));
  }

  /// Sets the level number and level-type of the `t`th tensor on
  /// `i`th loop.  Establishes both directions of the loop<->level maps.
  void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt) {
    assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));
    lvlTypes[t][i] = lt;
    loopToLvl[t][i] = lvl;
    lvlToLoop[t][lvl] = i;
    // TODO: favor a constant loop bound when there are multiple choices.
    loopBounds[i] = std::make_pair(t, lvl);
  }

  using ForeachTensorLoopIdCallback = function_ref<void(
      TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>;

  /// Iterates over a set of `TensorLoopId`s, invoking the callback
  /// for each `TensorLoopId` and passing it the corresponding tensor
  /// identifier, level, and level-type, following with a boolean value
  /// indicating whether it is a dependent index reduction loop condition.
  void foreachTensorLoopId(LatPointId p,
                           ForeachTensorLoopIdCallback callback) const {
    // TODO: the default ought to be simple=true; but we'll need to make
    // sure to update all the tests to make sure they do the right thing.
    foreachTensorLoopId(p, /*simple=*/false, callback);
  }
  void foreachTensorLoopId(LatPointId p, bool simple,
                           ForeachTensorLoopIdCallback callback) const {
    const auto &point = lat(p);
    const auto &bits = simple ? point.simple : point.bits;
    for (const TensorLoopId b : bits.set_bits()) {
      const TensorId t = tensor(b);
      const auto optLvl = getLvl(b);
      const auto lvlTp = getLvlType(b);
      if (isLvlWithNonTrivialIdxExp(b)) {
        // This must be an undefined level.
        assert(!optLvl.has_value());
        // Slice the tid along the dependent level to iterate current loop.
        callback(b, t, getLoopDependentLevel(b), lvlTp,
                 /*isIdxReduc=*/true);
      } else {
        callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false);
      }
    }
  }

  /// Sets whether the output tensor is sparse or not.
  void setHasSparseOut(bool s) { hasSparseOut = s; }

  /// Establishes the two-way map that i <-> <t, lvl, lt>.
  void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
                                   LevelType lt, unsigned coefficient) {
    assert(isValidLoopId(i) && isValidLevel(t, lvl));
    assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
    loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
    levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
  }

  /// Whether the loop has dependent slice.
  bool hasDependentLvl(LoopId i, TensorId t) {
    assert(isValidTensorId(t) && isValidLoopId(i));
    return loopToUnresolvedLvls[i][t].has_value();
  }

  /// Returns the list of loop indices which appear in the non-trivial index
  /// expression on t_l, e.g., A[i+j] => {i, j}
  std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
    assert(isValidLevel(t, lvl));
    return levelToDependentLoop[t][lvl];
  }

  /// Returns the defining [tid, lvl] for the loop.
  std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
    assert(isValidLoopId(i));
    return loopBounds[i];
  }

  /// Checks whether the TensorLoopId represents a tensor level that contains
  /// a non-trivial index expression.
  bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const {
    const TensorId t = tensor(b);
    const LoopId i = loop(b);
    assert(isValidTensorId(t) && isValidLoopId(i));
    return loopToUnresolvedLvls[i][t].has_value();
  }

  /// Checks whether the TensorLoopId represents a sparse tensor level that
  /// contains a non-trivial index expression.
  bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
    if (isLvlWithNonTrivialIdxExp(b)) {
      auto lt = getLoopDependentLevelType(b);
      return lt.hasSparseSemantic();
    }
    return false;
  }

  /// Returns the level of the tensor level that the loop `b` depends on.
  Level getLoopDependentLevel(TensorLoopId b) const {
    assert(isLvlWithNonTrivialIdxExp(b));
    return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
  }

  /// Returns the level-type of the tensor level that loop `b` depends on.
  LevelType getLoopDependentLevelType(TensorLoopId b) const {
    assert(isLvlWithNonTrivialIdxExp(b));
    return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
  }

  /// Convenience getters to immediately access the stored nodes.
  /// These methods return `const&` because the underlying objects must
  /// not be mutated by client code.  The only exception is for mutating
  /// the value associated with an expression, for which there are
  /// dedicated methods below.
  ///
  /// NOTE: It is inadvisable to keep the reference alive for a long
  /// time (e.g., as in `TensorExpr &te = merger.exp(e)`), since insertions
  /// into the merger can cause data movement which will invalidate the
  /// underlying memory address.  This isn't just a problem with the `&`
  /// references, but also applies to the `ArrayRef`.  In particular,
  /// using `for (LatPointId p : merger.set(s))` will run into the same
  /// dangling-reference problems if the loop body inserts new sets.
  const TensorExp &exp(ExprId e) const {
    assert(isValidExprId(e));
    return tensorExps[e];
  }
  const LatPoint &lat(LatPointId p) const {
    assert(isValidLatPointId(p));
    return latPoints[p];
  }
  ArrayRef<LatPointId> set(LatSetId s) const {
    assert(isValidLatSetId(s));
    return latSets[s];
  }

  /// Checks whether the given expression has an associated value.
  bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }

  /// Sets the expression to have the associated value. Asserts that the new
  /// value is defined, and that the expression does not already have a value.
  void setExprValue(ExprId e, Value v) {
    assert(!exp(e).val && "Expression already has an associated value");
    assert(v && "Trying to assign an undefined value");
    tensorExps[e].val = v;
  }

  /// Clears the value associated with the expression. Asserts that the
  /// expression does indeed have an associated value before clearing it.
  void clearExprValue(ExprId e) {
    assert(exp(e).val && "Expression does not have an associated value");
    tensorExps[e].val = Value();
  }

#ifndef NDEBUG
  /// Print methods (for debugging).
  void dumpExp(ExprId e) const;
  void dumpLat(LatPointId p) const;
  void dumpSet(LatSetId s) const;
  void dumpBits(const BitVector &bits) const;
#endif

  /// Builds the iteration lattices in a bottom-up traversal given the
  /// remaining tensor (sub)expression and the next loop in the iteration
  /// graph.  Returns the identifier of the root set.
  LatSetId buildLattices(ExprId e, LoopId i);

  /// Builds a tensor expression from the given Linalg operation.
  /// On success, returns the identifier of the root expression.
  std::optional<ExprId> buildTensorExpFromLinalg(linalg::GenericOp op);

  /// Rebuilds SSA format from a tensor expression.
  Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                 Value v1) const;

private:
  /// Private helpers.
  constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }
  constexpr bool isValidLoopId(LoopId i) const {
    return i != detail::kInvalidId && i < numLoops;
  }
  bool isValidLevel(TensorId t, Level lvl) const {
    assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
    return isValidTensorId(t) && lvl < lvlToLoop[t].size();
  }
  bool isValidExprId(ExprId e) const {
    return e != detail::kInvalidId && e < tensorExps.size();
  }
  bool isValidLatPointId(LatPointId p) const {
    return p != detail::kInvalidId && p < latPoints.size();
  }
  bool isValidLatSetId(LatSetId s) const {
    return s != detail::kInvalidId && s < latSets.size();
  }
  bool maybeZero(ExprId e) const;
  bool isInvariant(ExprId e) const {
    return exp(e).kind == TensorExp::Kind::kInvariant;
  }
  Type inferType(ExprId e, Value src) const;

  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
  /// The boolean value returned indicates whether the result of the current
  /// operation being built depends on any value that is loaded from a sparse
  /// tensor.
  std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
                                                        Value v);

  /// Merger data structures.
  const TensorId outTensor;
  const TensorId syntheticTensor;
  const unsigned numTensors;
  const unsigned numLoops;
  bool hasSparseOut;

  // Below we use `std::vector` for things which have a priori fixed
  // sizes, whereas we use `llvm::SmallVector` for things with variable
  // size.  Do beware that these two classes differ in the semantics of
  // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
  // does not.

  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
  std::vector<std::vector<LevelType>> lvlTypes;

  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
  std::vector<std::vector<std::optional<Level>>> loopToLvl;

  /// Map that converts pair<TensorId, Level> to the corresponding LoopId.
  std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;

  /// Map from a loop to its dependencies if any.
  /// The dependencies of a loop is a set of (tensor, level) pairs.
  /// It is currently only set for non-trivial index expressions.
  /// E.g., A[i+j] => i and j will have dependencies {A0, lt(A0)} to indicate
  /// that i and j are used in the non-trivial index expression on A0.
  std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls;

  /// The inverse map of `loopToUnresolvedLvls`: from tensor level -> dependent
  /// loops.  E.g., A[2i+j], we have A0 => {(2, i), (1, j)}, to indicate that
  /// A0 uses both {i, j} to compute its indices and the coefficients on the
  /// loop id are 2 and 1 respectively.
  std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;

  /// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
  std::vector<std::pair<TensorId, Level>> loopBounds;

  llvm::SmallVector<TensorExp> tensorExps;
  llvm::SmallVector<LatPoint> latPoints;
  llvm::SmallVector<SmallVector<LatPointId>> latSets;
};
667 
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_