xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (revision 1944c4f76b47c0b86c91845987baca24fd4775f8)
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace sparse_tensor {
21 
22 enum class ExpArity {
23   kNullary,
24   kUnary,
25   kBinary,
26 };
27 
28 static ExpArity getExpArity(TensorExp::Kind k) {
29   switch (k) {
30   // Leaf.
31   case TensorExp::Kind::kTensor:
32   case TensorExp::Kind::kInvariant:
33   case TensorExp::Kind::kLoopVar:
34   case TensorExp::Kind::kSynZero:
35     return ExpArity::kNullary;
36   case TensorExp::Kind::kAbsF:
37   case TensorExp::Kind::kAbsC:
38   case TensorExp::Kind::kAbsI:
39   case TensorExp::Kind::kCeilF:
40   case TensorExp::Kind::kFloorF:
41   case TensorExp::Kind::kSqrtF:
42   case TensorExp::Kind::kSqrtC:
43   case TensorExp::Kind::kExpm1F:
44   case TensorExp::Kind::kExpm1C:
45   case TensorExp::Kind::kLog1pF:
46   case TensorExp::Kind::kLog1pC:
47   case TensorExp::Kind::kSinF:
48   case TensorExp::Kind::kSinC:
49   case TensorExp::Kind::kTanhF:
50   case TensorExp::Kind::kTanhC:
51   case TensorExp::Kind::kTruncF:
52   case TensorExp::Kind::kExtF:
53   case TensorExp::Kind::kCastFS:
54   case TensorExp::Kind::kCastFU:
55   case TensorExp::Kind::kCastSF:
56   case TensorExp::Kind::kCastUF:
57   case TensorExp::Kind::kCastS:
58   case TensorExp::Kind::kCastU:
59   case TensorExp::Kind::kCastIdx:
60   case TensorExp::Kind::kTruncI:
61   case TensorExp::Kind::kCIm:
62   case TensorExp::Kind::kCRe:
63   case TensorExp::Kind::kBitCast:
64   case TensorExp::Kind::kBinaryBranch:
65   case TensorExp::Kind::kUnary:
66   case TensorExp::Kind::kSelect:
67   case TensorExp::Kind::kNegF:
68   case TensorExp::Kind::kNegC:
69   case TensorExp::Kind::kNegI:
70     return ExpArity::kUnary;
71   // Binary operations.
72   case TensorExp::Kind::kDivF:
73   case TensorExp::Kind::kDivC:
74   case TensorExp::Kind::kDivS:
75   case TensorExp::Kind::kDivU:
76   case TensorExp::Kind::kShrS:
77   case TensorExp::Kind::kShrU:
78   case TensorExp::Kind::kShlI:
79   case TensorExp::Kind::kMulF:
80   case TensorExp::Kind::kMulC:
81   case TensorExp::Kind::kMulI:
82   case TensorExp::Kind::kAndI:
83   case TensorExp::Kind::kAddF:
84   case TensorExp::Kind::kAddC:
85   case TensorExp::Kind::kAddI:
86   case TensorExp::Kind::kOrI:
87   case TensorExp::Kind::kXorI:
88   case TensorExp::Kind::kBinary:
89   case TensorExp::Kind::kReduce:
90   case TensorExp::Kind::kSubF:
91   case TensorExp::Kind::kSubC:
92   case TensorExp::Kind::kSubI:
93   case TensorExp::Kind::kCmpF:
94   case TensorExp::Kind::kCmpI:
95   case TensorExp::Kind::kDenseOp: // kDenseOp can have *at most* two operands
96     return ExpArity::kBinary;
97   }
98   llvm_unreachable("unexpected kind");
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Constructors.
103 //===----------------------------------------------------------------------===//
104 
105 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
106                      Operation *o, Attribute a)
107     : kind(k), val(v), op(o) {
108   switch (kind) {
109   // Leaf.
110   case TensorExp::Kind::kTensor:
111     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
112     tensor = x;
113     return;
114   case TensorExp::Kind::kSynZero:
115     assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
116     return;
117   case TensorExp::Kind::kInvariant:
118     assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
119     return;
120   case TensorExp::Kind::kLoopVar:
121     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
122     loop = x;
123     return;
124   // Unary operations.
125   case TensorExp::Kind::kAbsF:
126   case TensorExp::Kind::kAbsC:
127   case TensorExp::Kind::kAbsI:
128   case TensorExp::Kind::kCeilF:
129   case TensorExp::Kind::kFloorF:
130   case TensorExp::Kind::kSqrtF:
131   case TensorExp::Kind::kSqrtC:
132   case TensorExp::Kind::kExpm1F:
133   case TensorExp::Kind::kExpm1C:
134   case TensorExp::Kind::kLog1pF:
135   case TensorExp::Kind::kLog1pC:
136   case TensorExp::Kind::kSinF:
137   case TensorExp::Kind::kSinC:
138   case TensorExp::Kind::kTanhF:
139   case TensorExp::Kind::kTanhC:
140   case TensorExp::Kind::kNegF:
141   case TensorExp::Kind::kNegC:
142   case TensorExp::Kind::kNegI:
143   case TensorExp::Kind::kCIm:
144   case TensorExp::Kind::kCRe:
145     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
146     children.e0 = x;
147     children.e1 = y;
148     return;
149   case TensorExp::Kind::kTruncF:
150   case TensorExp::Kind::kExtF:
151   case TensorExp::Kind::kCastFS:
152   case TensorExp::Kind::kCastFU:
153   case TensorExp::Kind::kCastSF:
154   case TensorExp::Kind::kCastUF:
155   case TensorExp::Kind::kCastS:
156   case TensorExp::Kind::kCastU:
157   case TensorExp::Kind::kCastIdx:
158   case TensorExp::Kind::kTruncI:
159   case TensorExp::Kind::kBitCast:
160     assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
161     children.e0 = x;
162     children.e1 = y;
163     return;
164   case TensorExp::Kind::kBinaryBranch:
165   case TensorExp::Kind::kSelect:
166     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
167     children.e0 = x;
168     children.e1 = y;
169     return;
170   case TensorExp::Kind::kUnary:
171     // No assertion on y can be made, as the branching paths involve both
172     // a unary (`mapSet`) and binary (`disjSet`) pathway.
173     assert(x != detail::kInvalidId && !v && o);
174     children.e0 = x;
175     children.e1 = y;
176     return;
177   // Binary operations.
178   case TensorExp::Kind::kMulF:
179   case TensorExp::Kind::kMulC:
180   case TensorExp::Kind::kMulI:
181   case TensorExp::Kind::kDivF:
182   case TensorExp::Kind::kDivC:
183   case TensorExp::Kind::kDivS:
184   case TensorExp::Kind::kDivU:
185   case TensorExp::Kind::kAddF:
186   case TensorExp::Kind::kAddC:
187   case TensorExp::Kind::kAddI:
188   case TensorExp::Kind::kSubF:
189   case TensorExp::Kind::kSubC:
190   case TensorExp::Kind::kSubI:
191   case TensorExp::Kind::kAndI:
192   case TensorExp::Kind::kOrI:
193   case TensorExp::Kind::kXorI:
194   case TensorExp::Kind::kShrS:
195   case TensorExp::Kind::kShrU:
196   case TensorExp::Kind::kShlI:
197     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
198     children.e0 = x;
199     children.e1 = y;
200     return;
201   case TensorExp::Kind::kCmpF:
202   case TensorExp::Kind::kCmpI:
203     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
204     attr = a;
205     children.e0 = x;
206     children.e1 = y;
207     return;
208   case TensorExp::Kind::kBinary:
209   case TensorExp::Kind::kReduce:
210     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
211     children.e0 = x;
212     children.e1 = y;
213     return;
214   case TensorExp::Kind::kDenseOp:
215     assert(x != detail::kInvalidId && !v && o);
216     children.e0 = x;
217     children.e1 = y;
218     return;
219   }
220   llvm_unreachable("unexpected kind");
221 }
222 
223 Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
224                unsigned maxLvlRank)
225     : outTensor(numInputOutputTensors - 1),
226       syntheticTensor(numInputOutputTensors),
227       numTensors(numInputOutputTensors + 1), numLoops(numLoops),
228       hasSparseOut(false),
229       lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
230       loopToLvl(numTensors,
231                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
232       lvlToLoop(numTensors,
233                 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
234       loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
235                                          numTensors, std::nullopt)),
236       levelToDependentLoop(numTensors,
237                            std::vector<std::vector<LoopCoeffPair>>(
238                                maxLvlRank, std::vector<LoopCoeffPair>())),
239       loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
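
// For illustration: constructing Merger(/*numInputOutputTensors=*/3,
// /*numLoops=*/2, /*maxLvlRank=*/2) for a kernel such as C(i,j) = A(i,j) *
// B(i,j) assigns ids 0 and 1 to the input tensors, outTensor = 2,
// syntheticTensor = 3, and numTensors = 4, since the synthetic tensor is
// appended after the output tensor.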
240 
241 //===----------------------------------------------------------------------===//
242 // Lattice methods.
243 //===----------------------------------------------------------------------===//
244 
245 ExprId Merger::addTensorExp(TensorId t) {
246   assert(isValidTensorId(t));
247   const ExprId eNew(tensorExps.size());
248   tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
249                           Value(), nullptr, nullptr);
250   return eNew;
251 }
252 
253 ExprId Merger::addLoopVarExp(LoopId i) {
254   assert(isValidLoopId(i));
255   const ExprId eNew(tensorExps.size());
256   tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
257                           Value(), nullptr, nullptr);
258   return eNew;
259 }
260 
261 ExprId Merger::addInvariantExp(Value v) {
262   const ExprId eNew(tensorExps.size());
263   tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
264                           detail::kInvalidId, v, nullptr, nullptr);
265   return eNew;
266 }
267 
268 ExprId Merger::addSynZeroExp() {
269   const ExprId eNew(tensorExps.size());
270   tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
271                           detail::kInvalidId, Value(), nullptr, nullptr);
272   return eNew;
273 }
274 
275 ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
276                       Attribute attr) {
277   assert(k > TensorExp::Kind::kLoopVar);
278   const ExprId eNew(tensorExps.size());
279   tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
280   return eNew;
281 }
282 
283 ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
284                       Attribute attr) {
285   assert(k > TensorExp::Kind::kLoopVar);
286   const ExprId eNew(tensorExps.size());
287   tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
288   return eNew;
289 }
290 
291 LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
292   const LatPointId pNew(latPoints.size());
293   const unsigned size = numLoops * numTensors;
294   const TensorLoopId b = makeTensorLoopId(t, i);
295   latPoints.emplace_back(size, e);
296   latPoints[pNew].bits.set(b);
297   return pNew;
298 }
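
// Note: each lattice point carries a bit vector of size numLoops * numTensors,
// and the bit obtained from makeTensorLoopId(t, i) records that tensor t
// contributes a condition to loop i. For example, adding a lattice point for a
// sparse operand a(i) sets exactly one such bit, the one for the pair (a, i).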
299 
300 LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
301   assert(bits.size() == numLoops * numTensors);
302   const LatPointId pNew(latPoints.size());
303   latPoints.emplace_back(bits, e);
304   return pNew;
305 }
306 
307 LatSetId Merger::addSet() {
308   const LatSetId sNew(latSets.size());
309   latSets.emplace_back();
310   return sNew;
311 }
312 
313 LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
314                            Operation *op) {
315   TensorExp::Kind kind = exp(e).kind;
316   Attribute attr = exp(e).attr;
317   const LatPointId pNew(latPoints.size());
318   const auto &point0 = lat(p0);
319   const auto &point1 = lat(p1);
320   BitVector bits(point0.bits);
321   bits |= point1.bits;
322   const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
323   latPoints.emplace_back(bits, ne);
324   return pNew;
325 }
326 
327 LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
328   const LatSetId sNew = addSet();
329   auto &setNew = latSets[sNew];
330   for (const LatPointId p0 : set(s0))
331     for (const LatPointId p1 : set(s1))
332       setNew.push_back(conjLat(e, p0, p1, op));
333   return sNew;
334 }
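
// For example, for a(i) * b(i) with operand lattice sets {a} and {b}, the
// conjunction yields the single point {a /\ b}: the product only needs to be
// computed where both operands have stored entries.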
335 
336 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
337   const LatSetId sNew = conjSet(e, s0, s1, op);
338   TensorExp::Kind kind = exp(e).kind;
339 
340   // Followed by all in s0.
341   latSets[sNew].append(latSets[s0]);
342   // Map binary 0-y to unary -y.
343   // TODO: move this if-else logic into buildLattices
344   if (kind == TensorExp::Kind::kSubF)
345     s1 = mapSet(TensorExp::Kind::kNegF, s1);
346   else if (kind == TensorExp::Kind::kSubC)
347     s1 = mapSet(TensorExp::Kind::kNegC, s1);
348   else if (kind == TensorExp::Kind::kSubI)
349     s1 = mapSet(TensorExp::Kind::kNegI, s1);
350   // Followed by all in s1.
351   latSets[sNew].append(latSets[s1]);
352   return sNew;
353 }
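
// For example, for a(i) + b(i) with operand lattice sets {a} and {b}, the
// disjunction yields {a /\ b, a, b}: the overlap computes a+b, while the
// remaining points copy the single present operand (or negate it for
// subtraction, per the mapping above).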
354 
355 LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
356   assert(exp(e).kind == TensorExp::Kind::kCmpI ||
357          exp(e).kind == TensorExp::Kind::kCmpF);
358   const LatSetId sNew = conjSet(e, s0, s1, nullptr);
359 
360   ExprId e0 = exp(e).children.e0;
361   ExprId e1 = exp(e).children.e1;
362   if (exp(e0).kind == TensorExp::Kind::kSynZero ||
363       exp(e1).kind == TensorExp::Kind::kSynZero) {
364     // lhs and rhs can't be synthetic zero at the same time.
365     assert(exp(e0).kind != exp(e1).kind);
366     // If one of the operands has already been assigned to zero (the
367     // element is absent in the corresponding operand), then we do not
368     // need to build a disjunctive set for it.
369     return sNew;
370   }
371 
372   auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
373   auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
374   latSets[sNew].append(latSets[lhsSet]);
375   latSets[sNew].append(latSets[rhsSet]);
376   return sNew;
377 }
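
// For example, for a(i) < b(i) with operand lattice sets {a} and {b}, this
// yields {a /\ b, a, b}, where the "a only" point evaluates a < 0 and the
// "b only" point evaluates 0 < b, using the synthetic zero expression for the
// absent operand.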
378 
379 LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
380                           bool includeLeft, TensorExp::Kind ltrans,
381                           Operation *opleft, bool includeRight,
382                           TensorExp::Kind rtrans, Operation *opright) {
383   const LatSetId sNew = conjSet(e, s0, s1, orig);
384   // Left Region.
385   if (includeLeft) {
386     if (opleft)
387       s0 = mapSet(ltrans, s0, Value(), opleft);
388     latSets[sNew].append(latSets[s0]);
389   }
390   // Right Region.
391   if (includeRight) {
392     if (opright)
393       s1 = mapSet(rtrans, s1, Value(), opright);
394     latSets[sNew].append(latSets[s1]);
395   }
396   return sNew;
397 }
398 
399 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
400                         Operation *op) {
401   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
402          TensorExp::Kind::kDenseOp == kind);
403   const LatSetId sNew = addSet();
404   auto &setNew = latSets[sNew];
405   for (const LatPointId p : set(s0)) {
406     const auto &point = latPoints[p];
407     setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
408   }
409   return sNew;
410 }
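
// For example, mapping kAbsF over a set with points {a, a /\ b} produces a new
// set whose points keep the same bits, but whose expressions are rewritten to
// abs(a) and abs(a /\ b), respectively.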
411 
412 LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
413   TensorExp::Kind kind = exp(e).kind;
414   Attribute a = exp(e).attr;
415   assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
416   // Must be a binary operation.
417   const LatSetId sNew = addSet();
418   auto &setNew = latSets[sNew];
419   const ExprId zeroExp = addSynZeroExp();
420   for (const LatPointId p : set(s0)) {
421     const auto &point = latPoints[p];
422     ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
423                             : addExp(kind, point.exp, zeroExp, nullptr, a);
424     setNew.push_back(addLat(point.bits, newExp));
425   }
426   return sNew;
427 }
428 
429 LatSetId Merger::optimizeSet(LatSetId s0) {
430   const LatSetId sNew = addSet();
431   auto &setNew = latSets[sNew];
432   const auto &set0 = set(s0);
433   assert(!set0.empty());
434   const LatPointId p0 = set0[0];
435   for (const LatPointId p1 : set0) {
436     bool add = true;
437     if (p0 != p1) {
438       // Check whether this is a straightforward copy.
439       if (expIsTensor(latPoints[p1].exp, outTensor))
440         continue;
441       // Check whether this conjunction is already covered.
442       for (const LatPointId p2 : setNew) {
443         assert(!latGT(p1, p2)); // Lj => Li would be bad
444         if (onlyDenseDiff(p2, p1)) {
445           add = false;
446           break;
447         }
448       }
449       assert(!add || latGT(p0, p1));
450     }
451     if (add)
452       setNew.push_back(p1);
453   }
454   for (const LatPointId p : setNew)
455     latPoints[p].simple = simplifyCond(sNew, p);
456   return sNew;
457 }
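
// For example, if a(i) is sparse and d(i) is dense, the disjunction for
// a(i) + d(i) contains {a /\ d, a, d}; the point {a} differs from {a /\ d}
// only in the dense bit for d, so it is redundant and removed here, leaving
// {a /\ d, d}.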
458 
459 BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
460   // First determine if this lattice point is a *singleton*, i.e.,
461   // the last point in a lattice: no other point is less than this one.
462   bool isSingleton = true;
463   for (const LatPointId p1 : set(s0)) {
464     if (p0 != p1 && latGT(p0, p1)) {
465       isSingleton = false;
466       break;
467     }
468   }
469 
470   BitVector simple(latPoints[p0].bits);
471   bool reset = isSingleton && hasAnySparse(simple);
472   const TensorLoopId be = simple.size();
473   TensorLoopId offset = 0; // relative to the end
474   if (!reset)
475     // Start resetting from a dense level, so that the first bit (if kept)
476     // is not an undefined level-type.
477     for (unsigned b = 0; b < be; b++) {
478       if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
479         offset = be - b - 1; // relative to the end
480         break;
481       }
482     }
483 
484   // Now apply the two basic rules. We also iterate over the bits in reverse to
485   // always keep the rightmost bit (which could possibly be a synthetic tensor).
486   for (unsigned b = be - 1 - offset, i = 0; i < be;
487        b = b == 0 ? be - 1 : b - 1, i++) {
488     // A slice on a dense level also has the `locate` property and can be optimized.
489     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
490       const auto lt = getLvlType(b);
491       if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
492           !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
493         if (reset)
494           simple.reset(b);
495         reset = true;
496       }
497     }
498   }
499   return simple;
500 }
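
// For example, for a(i) * d(i) with a sparse and d dense, the lattice holds
// the single point {a /\ d}; its simplified condition keeps only the sparse
// bit for a, since the dense level can be accessed directly while co-iterating
// over a.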
501 
502 bool Merger::latGT(LatPointId i, LatPointId j) const {
503   const BitVector &bitsi = lat(i).bits;
504   const BitVector &bitsj = lat(j).bits;
505   assert(bitsi.size() == bitsj.size());
506   if (bitsi.count() > bitsj.count()) {
507     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
508       if (bitsj[b] && !bitsi[b])
509         return false;
510     return true;
511   }
512   return false;
513 }
514 
515 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
516   BitVector tmp(latPoints[j].bits);
517   tmp ^= latPoints[i].bits;
518   return !hasAnySparse(tmp);
519 }
520 
521 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
522   const auto &expr = exp(e);
523   // First we check `expIsTensor`.
524   if (expr.kind == TensorExp::Kind::kTensor)
525     return expr.tensor == t;
526 
527   switch (getExpArity(expr.kind)) {
528   case ExpArity::kNullary:
529     return false;
530   case ExpArity::kUnary: {
531     const ExprId e0 = expr.children.e0;
532     return expContainsTensor(e0, t);
533   }
534   case ExpArity::kBinary: {
535     const ExprId e0 = expr.children.e0;
536     const ExprId e1 = expr.children.e1;
537     return expContainsTensor(e0, t) || expContainsTensor(e1, t);
538   }
539   }
540   llvm_unreachable("unexpected arity");
541 }
542 
543 bool Merger::hasNegateOnOut(ExprId e) const {
544   const auto &expr = exp(e);
545   switch (expr.kind) {
546   case TensorExp::Kind::kNegF:
547   case TensorExp::Kind::kNegC:
548   case TensorExp::Kind::kNegI:
549     return expContainsTensor(expr.children.e0, outTensor);
550   case TensorExp::Kind::kSubF:
551   case TensorExp::Kind::kSubC:
552   case TensorExp::Kind::kSubI:
553     return expContainsTensor(expr.children.e1, outTensor) ||
554            hasNegateOnOut(expr.children.e0);
555   case TensorExp::Kind::kDenseOp: {
556     bool lhsNeg = hasNegateOnOut(expr.children.e0);
557     if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
558       return hasNegateOnOut(expr.children.e1);
559     return lhsNeg;
560   }
561   default: {
562     switch (getExpArity(expr.kind)) {
563     case ExpArity::kNullary:
564       return false;
565     case ExpArity::kUnary:
566       return hasNegateOnOut(expr.children.e0);
567     case ExpArity::kBinary:
568       return hasNegateOnOut(expr.children.e0) ||
569              hasNegateOnOut(expr.children.e1);
570     }
571   }
572   }
573   llvm_unreachable("unexpected kind");
574 }
575 
576 bool Merger::isSingleCondition(TensorId t, ExprId e) const {
577   assert(isValidTensorId(t));
578   const auto &expr = exp(e);
579   switch (expr.kind) {
580   // Leaf.
581   case TensorExp::Kind::kTensor:
582     return expr.tensor == t;
583   case TensorExp::Kind::kInvariant:
584   case TensorExp::Kind::kLoopVar:
585   case TensorExp::Kind::kSynZero:
586     return false;
587   // Unary operations.
588   case TensorExp::Kind::kAbsF:
589   case TensorExp::Kind::kAbsC:
590   case TensorExp::Kind::kAbsI:
591   case TensorExp::Kind::kCeilF:
592   case TensorExp::Kind::kFloorF:
593   case TensorExp::Kind::kSqrtF:
594   case TensorExp::Kind::kSqrtC:
595   case TensorExp::Kind::kExpm1F:
596   case TensorExp::Kind::kExpm1C:
597   case TensorExp::Kind::kLog1pF:
598   case TensorExp::Kind::kLog1pC:
599   case TensorExp::Kind::kSinF:
600   case TensorExp::Kind::kSinC:
601   case TensorExp::Kind::kTanhF:
602   case TensorExp::Kind::kTanhC:
603   case TensorExp::Kind::kNegF:
604   case TensorExp::Kind::kNegC:
605   case TensorExp::Kind::kNegI:
606   case TensorExp::Kind::kTruncF:
607   case TensorExp::Kind::kExtF:
608   case TensorExp::Kind::kCastFS:
609   case TensorExp::Kind::kCastFU:
610   case TensorExp::Kind::kCastSF:
611   case TensorExp::Kind::kCastUF:
612   case TensorExp::Kind::kCastS:
613   case TensorExp::Kind::kCastU:
614   case TensorExp::Kind::kCastIdx:
615   case TensorExp::Kind::kTruncI:
616   case TensorExp::Kind::kCIm:
617   case TensorExp::Kind::kCRe:
618   case TensorExp::Kind::kBitCast:
619   case TensorExp::Kind::kUnary:
620     return isSingleCondition(t, expr.children.e0);
621   case TensorExp::Kind::kBinaryBranch:
622   case TensorExp::Kind::kSelect:
623     return false;
624   // Binary operations.
625   case TensorExp::Kind::kDivF: // note: x / c only
626   case TensorExp::Kind::kDivC:
627   case TensorExp::Kind::kDivS:
628   case TensorExp::Kind::kDivU:
629     assert(!maybeZero(expr.children.e1));
630     return isSingleCondition(t, expr.children.e0);
631   case TensorExp::Kind::kShrS: // note: x >> inv only
632   case TensorExp::Kind::kShrU:
633   case TensorExp::Kind::kShlI:
634     assert(isInvariant(expr.children.e1));
635     return isSingleCondition(t, expr.children.e0);
636   case TensorExp::Kind::kMulF:
637   case TensorExp::Kind::kMulC:
638   case TensorExp::Kind::kMulI:
639   case TensorExp::Kind::kAndI:
640   case TensorExp::Kind::kReduce:
641     if (isSingleCondition(t, expr.children.e0))
642       return isSingleCondition(t, expr.children.e1) ||
643              isInvariant(expr.children.e1);
644     if (isSingleCondition(t, expr.children.e1))
645       return isInvariant(expr.children.e0);
646     return false;
647   case TensorExp::Kind::kAddF:
648   case TensorExp::Kind::kAddC:
649   case TensorExp::Kind::kAddI:
650     return isSingleCondition(t, expr.children.e0) &&
651            isSingleCondition(t, expr.children.e1);
652   case TensorExp::Kind::kSubF:
653   case TensorExp::Kind::kSubC:
654   case TensorExp::Kind::kSubI:
655   case TensorExp::Kind::kOrI:
656   case TensorExp::Kind::kXorI:
657   case TensorExp::Kind::kCmpF:
658   case TensorExp::Kind::kCmpI:
659   case TensorExp::Kind::kBinary:
660     return false;
661   case TensorExp::Kind::kDenseOp:
662     // Since the Merger guarantees that all operands of kDenseOp are dense, the
663     // operation must be single-condition.
664     return true;
665   }
666   llvm_unreachable("unexpected kind");
667 }
668 
669 bool Merger::hasAnySparse(const BitVector &bits) const {
670   for (TensorLoopId b : bits.set_bits()) {
671     const auto lt = getLvlType(b);
672     if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
673         is2OutOf4LT(lt))
674       return true;
675   }
676   return hasSparseIdxReduction(bits);
677 }
678 
679 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
680   for (TensorLoopId b : bits.set_bits())
681     if (isSparseLvlWithNonTrivialIdxExp(b))
682       return true;
683   return false;
684 }
685 
686 #ifndef NDEBUG
687 
688 //===----------------------------------------------------------------------===//
689 // Print methods (for debugging).
690 //===----------------------------------------------------------------------===//
691 
692 static const char *kindToOpSymbol(TensorExp::Kind kind) {
693   switch (kind) {
694   // Leaf.
695   case TensorExp::Kind::kTensor:
696     return "tensor";
697   case TensorExp::Kind::kInvariant:
698     return "invariant";
699   case TensorExp::Kind::kLoopVar:
700     return "index";
701   case TensorExp::Kind::kSynZero:
702     return "0";
703   // Unary operations.
704   case TensorExp::Kind::kAbsF:
705   case TensorExp::Kind::kAbsC:
706   case TensorExp::Kind::kAbsI:
707     return "abs";
708   case TensorExp::Kind::kCeilF:
709     return "ceil";
710   case TensorExp::Kind::kFloorF:
711     return "floor";
712   case TensorExp::Kind::kSqrtF:
713   case TensorExp::Kind::kSqrtC:
714     return "sqrt";
715   case TensorExp::Kind::kExpm1F:
716   case TensorExp::Kind::kExpm1C:
717     return "expm1";
718   case TensorExp::Kind::kLog1pF:
719   case TensorExp::Kind::kLog1pC:
720     return "log1p";
721   case TensorExp::Kind::kSinF:
722   case TensorExp::Kind::kSinC:
723     return "sin";
724   case TensorExp::Kind::kTanhF:
725   case TensorExp::Kind::kTanhC:
726     return "tanh";
727   case TensorExp::Kind::kNegF:
728   case TensorExp::Kind::kNegC:
729   case TensorExp::Kind::kNegI:
730     return "-";
731   case TensorExp::Kind::kTruncF:
732   case TensorExp::Kind::kExtF:
733   case TensorExp::Kind::kCastFS:
734   case TensorExp::Kind::kCastFU:
735   case TensorExp::Kind::kCastSF:
736   case TensorExp::Kind::kCastUF:
737   case TensorExp::Kind::kCastS:
738   case TensorExp::Kind::kCastU:
739   case TensorExp::Kind::kCastIdx:
740   case TensorExp::Kind::kTruncI:
741   case TensorExp::Kind::kBitCast:
742     return "cast";
743   case TensorExp::Kind::kCIm:
744     return "complex.im";
745   case TensorExp::Kind::kCRe:
746     return "complex.re";
747   case TensorExp::Kind::kBinaryBranch:
748     return "binary_branch";
749   case TensorExp::Kind::kUnary:
750     return "unary";
751   case TensorExp::Kind::kSelect:
752     return "select";
753   // Binary operations.
754   case TensorExp::Kind::kMulF:
755   case TensorExp::Kind::kMulC:
756   case TensorExp::Kind::kMulI:
757     return "*";
758   case TensorExp::Kind::kDivF:
759   case TensorExp::Kind::kDivC:
760   case TensorExp::Kind::kDivS:
761   case TensorExp::Kind::kDivU:
762     return "/";
763   case TensorExp::Kind::kAddF:
764   case TensorExp::Kind::kAddC:
765   case TensorExp::Kind::kAddI:
766     return "+";
767   case TensorExp::Kind::kSubF:
768   case TensorExp::Kind::kSubC:
769   case TensorExp::Kind::kSubI:
770     return "-";
771   case TensorExp::Kind::kAndI:
772     return "&";
773   case TensorExp::Kind::kOrI:
774     return "|";
775   case TensorExp::Kind::kXorI:
776     return "^";
777   case TensorExp::Kind::kShrS:
778     return "a>>";
779   case TensorExp::Kind::kShrU:
780     return ">>";
781   case TensorExp::Kind::kShlI:
782     return "<<";
783   case TensorExp::Kind::kCmpF:
784   case TensorExp::Kind::kCmpI:
785     return "cmp";
786   case TensorExp::Kind::kBinary:
787     return "binary";
788   case TensorExp::Kind::kReduce:
789     return "reduce";
790   case TensorExp::Kind::kDenseOp:
791     return "dense";
792   }
793   llvm_unreachable("unexpected kind for symbol");
794 }
795 
796 void Merger::dumpExp(ExprId e) const {
797   const auto &expr = exp(e);
798   switch (expr.kind) {
799   // Leaf.
800   case TensorExp::Kind::kTensor:
801     if (expr.tensor == syntheticTensor)
802       llvm::dbgs() << "synthetic_";
803     else if (expr.tensor == outTensor)
804       llvm::dbgs() << "output_";
805     llvm::dbgs() << "tensor_" << expr.tensor;
806     break;
807   case TensorExp::Kind::kInvariant:
808     llvm::dbgs() << "invariant";
809     break;
810   case TensorExp::Kind::kSynZero:
811     llvm::dbgs() << "0";
812     break;
813   case TensorExp::Kind::kLoopVar:
814     llvm::dbgs() << "loopvar_" << expr.loop;
815     break;
816   // Unary operations.
817   case TensorExp::Kind::kAbsF:
818   case TensorExp::Kind::kAbsC:
819   case TensorExp::Kind::kAbsI:
820   case TensorExp::Kind::kCeilF:
821   case TensorExp::Kind::kFloorF:
822   case TensorExp::Kind::kSqrtF:
823   case TensorExp::Kind::kSqrtC:
824   case TensorExp::Kind::kExpm1F:
825   case TensorExp::Kind::kExpm1C:
826   case TensorExp::Kind::kLog1pF:
827   case TensorExp::Kind::kLog1pC:
828   case TensorExp::Kind::kSinF:
829   case TensorExp::Kind::kSinC:
830   case TensorExp::Kind::kTanhF:
831   case TensorExp::Kind::kTanhC:
832   case TensorExp::Kind::kNegF:
833   case TensorExp::Kind::kNegC:
834   case TensorExp::Kind::kNegI:
835   case TensorExp::Kind::kTruncF:
836   case TensorExp::Kind::kExtF:
837   case TensorExp::Kind::kCastFS:
838   case TensorExp::Kind::kCastFU:
839   case TensorExp::Kind::kCastSF:
840   case TensorExp::Kind::kCastUF:
841   case TensorExp::Kind::kCastS:
842   case TensorExp::Kind::kCastU:
843   case TensorExp::Kind::kCastIdx:
844   case TensorExp::Kind::kTruncI:
845   case TensorExp::Kind::kCIm:
846   case TensorExp::Kind::kCRe:
847   case TensorExp::Kind::kBitCast:
848   case TensorExp::Kind::kBinaryBranch:
849   case TensorExp::Kind::kUnary:
850   case TensorExp::Kind::kSelect:
851     llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
852     dumpExp(expr.children.e0);
853     break;
854   // Binary operations.
855   case TensorExp::Kind::kMulF:
856   case TensorExp::Kind::kMulC:
857   case TensorExp::Kind::kMulI:
858   case TensorExp::Kind::kDivF:
859   case TensorExp::Kind::kDivC:
860   case TensorExp::Kind::kDivS:
861   case TensorExp::Kind::kDivU:
862   case TensorExp::Kind::kAddF:
863   case TensorExp::Kind::kAddC:
864   case TensorExp::Kind::kAddI:
865   case TensorExp::Kind::kSubF:
866   case TensorExp::Kind::kSubC:
867   case TensorExp::Kind::kSubI:
868   case TensorExp::Kind::kAndI:
869   case TensorExp::Kind::kOrI:
870   case TensorExp::Kind::kXorI:
871   case TensorExp::Kind::kShrS:
872   case TensorExp::Kind::kShrU:
873   case TensorExp::Kind::kShlI:
874   case TensorExp::Kind::kCmpF:
875   case TensorExp::Kind::kCmpI:
876   case TensorExp::Kind::kBinary:
877   case TensorExp::Kind::kReduce:
878   case TensorExp::Kind::kDenseOp:
879     llvm::dbgs() << "(";
880     dumpExp(expr.children.e0);
881     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
882     if (expr.attr)
883       llvm::dbgs() << "{" << expr.attr << "}";
884     if (expr.children.e1 != detail::kInvalidId) {
885       llvm::dbgs() << " ";
886       dumpExp(expr.children.e1);
887       llvm::dbgs() << ")";
888     } else {
889       assert(expr.kind == TensorExp::Kind::kDenseOp);
890     }
891     break;
892   }
893 }
894 
895 void Merger::dumpLat(LatPointId p) const {
896   const auto &point = lat(p);
897   llvm::dbgs() << "lat(";
898   dumpBits(point.bits);
899   llvm::dbgs() << " :";
900   dumpBits(point.simple);
901   llvm::dbgs() << " : ";
902   dumpExp(point.exp);
903   llvm::dbgs() << " )\n";
904 }
905 
906 void Merger::dumpSet(LatSetId s) const {
907   const auto &ss = set(s);
908   llvm::dbgs() << "{ #" << ss.size() << "\n";
909   for (const LatPointId p : ss) {
910     llvm::dbgs() << "  ";
911     dumpLat(p);
912   }
913   llvm::dbgs() << "}\n";
914 }
915 
916 void Merger::dumpBits(const BitVector &bits) const {
917   for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
918     if (bits[b]) {
919       const TensorId t = tensor(b);
920       const LoopId i = loop(b);
921       const auto lt = lvlTypes[t][i];
922       if (isLvlWithNonTrivialIdxExp(b))
923         llvm::dbgs() << " DEP_" << t << "_" << i;
924       else
925         llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
926     }
927   }
928 }
929 
930 #endif // NDEBUG
931 
932 //===----------------------------------------------------------------------===//
933 // Builder methods.
934 //===----------------------------------------------------------------------===//
935 
936 LatSetId Merger::buildLattices(ExprId e, LoopId i) {
937   // NOTE: The `expr` reference will be invalidated by recursive calls
938   // (and any other method that may add new expressions); therefore, the
939   // code below must make sure to copy fields of `expr` into local variables
940   // before making any recursive calls.
941   const auto &expr = exp(e);
942   const TensorExp::Kind kind = expr.kind;
943   switch (kind) {
944   // Leaf.
945   case TensorExp::Kind::kTensor:
946   case TensorExp::Kind::kInvariant:
947   case TensorExp::Kind::kSynZero:
948   case TensorExp::Kind::kLoopVar: {
949     // Either the loop-var is really used in the tensor expression, or it is
950     // set to the undefined loop-var in that level. An invariant expression,
951     // a proper index value, or a truly dynamic sparse output tensor is mapped
952     // to the synthetic tensor with undefined indices only, to ensure the
953     // iteration space is not skipped as a result of its contents.
954     const LatSetId s = addSet();
955     TensorId t = syntheticTensor;
956     if (kind == TensorExp::Kind::kTensor) {
957       t = expr.tensor;
958       if (hasSparseOut && t == outTensor)
959         t = syntheticTensor;
960     }
961     latSets[s].push_back(addLat(t, i, e));
962     return s;
963   }
964   // Unary operations.
965   case TensorExp::Kind::kAbsF:
966   case TensorExp::Kind::kAbsC:
967   case TensorExp::Kind::kAbsI:
968   case TensorExp::Kind::kCeilF:
969   case TensorExp::Kind::kFloorF:
970   case TensorExp::Kind::kSqrtF:
971   case TensorExp::Kind::kSqrtC:
972   case TensorExp::Kind::kExpm1F:
973   case TensorExp::Kind::kExpm1C:
974   case TensorExp::Kind::kLog1pF:
975   case TensorExp::Kind::kLog1pC:
976   case TensorExp::Kind::kSinF:
977   case TensorExp::Kind::kSinC:
978   case TensorExp::Kind::kTanhF:
979   case TensorExp::Kind::kTanhC:
980   case TensorExp::Kind::kNegF:
981   case TensorExp::Kind::kNegC:
982   case TensorExp::Kind::kNegI:
983   case TensorExp::Kind::kTruncF:
984   case TensorExp::Kind::kExtF:
985   case TensorExp::Kind::kCastFS:
986   case TensorExp::Kind::kCastFU:
987   case TensorExp::Kind::kCastSF:
988   case TensorExp::Kind::kCastUF:
989   case TensorExp::Kind::kCastS:
990   case TensorExp::Kind::kCastU:
991   case TensorExp::Kind::kCastIdx:
992   case TensorExp::Kind::kTruncI:
993   case TensorExp::Kind::kCIm:
994   case TensorExp::Kind::kCRe:
995   case TensorExp::Kind::kBitCast:
996     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
997     // lattice set of the operand through the operator into a new set.
998     //
999     //  -y|!y | y |
1000     //  --+---+---+
1001     //    | 0 |-y |
1002     {
1003       const ExprId e0 = expr.children.e0;
1004       const Value v = expr.val;
1005       return mapSet(kind, buildLattices(e0, i), v);
1006     }
1007   case TensorExp::Kind::kBinaryBranch:
1008   case TensorExp::Kind::kSelect:
1009     // The left or right half of a binary operation which has already
1010     // been split into separate operations for each region.
1011     {
1012       const ExprId e0 = expr.children.e0;
1013       Operation *const op = expr.op;
1014       return mapSet(kind, buildLattices(e0, i), Value(), op);
1015     }
1016   case TensorExp::Kind::kUnary:
1017     // A custom unary operation.
1018     //
1019     //  op y|    !y    |     y      |
1020     //  ----+----------+------------+
1021     //      | absent() | present(y) |
1022     {
1023       const ExprId e0 = expr.children.e0;
1024       UnaryOp unop = cast<UnaryOp>(expr.op);
1025       const LatSetId child0 = buildLattices(e0, i);
1026       Region &absentRegion = unop.getAbsentRegion();
1027       if (absentRegion.empty()) {
1028         // Simple mapping over existing values.
1029         return mapSet(kind, child0, Value(), unop);
1030       }
1031       // Use a disjunction with `unop` on the left and the absent value as an
1032       // invariant on the right.
1033       Block &absentBlock = absentRegion.front();
1034       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1035       const Value absentVal = absentYield.getResult();
1036       const ExprId rhs = addInvariantExp(absentVal);
1037       return disjSet(e, child0, buildLattices(rhs, i), unop);
1038     }
1039   // Binary operations.
1040   case TensorExp::Kind::kMulF:
1041   case TensorExp::Kind::kMulC:
1042   case TensorExp::Kind::kMulI:
1043   case TensorExp::Kind::kAndI:
1044     // A multiplicative operation only needs to be performed
1045     // for the conjunction of sparse iteration spaces.
1046     //
1047     //  x*y|!y | y |
1048     //  ---+---+---+
1049     //  !x | 0 | 0 |
1050     //   x | 0 |x*y|
1051     //
1052     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1053     {
1054       const ExprId e0 = expr.children.e0;
1055       const ExprId e1 = expr.children.e1;
1056       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1057     }
1058   case TensorExp::Kind::kDivF:
1059   case TensorExp::Kind::kDivC:
1060   case TensorExp::Kind::kDivS:
1061   case TensorExp::Kind::kDivU:
1062     // A division is tricky, since 0/0, 0/c, c/0 all have
1063     // specific outcomes for floating-point and integers.
1064     // Thus, we need to traverse the full iteration space.
1065     //
1066     //  x/y|!y | y |
1067     //  ---+---+---+
1068     //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1069     //   x |x/0|x/y|  INT: x/0=exception for any x
1070     //
1071     // TODO: for now we "fixed" this by only accepting x/c cases
1072     //       during expression building, so that the conjunction
1073     //       rule applies (viz. x/c = x*(1/c) as far as lattice
1074     //       construction is concerned).
1075     {
1076       const ExprId e0 = expr.children.e0;
1077       const ExprId e1 = expr.children.e1;
1078       assert(!maybeZero(e1));
1079       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1080     }
1081   case TensorExp::Kind::kAddF:
1082   case TensorExp::Kind::kAddC:
1083   case TensorExp::Kind::kAddI:
1084   case TensorExp::Kind::kSubF:
1085   case TensorExp::Kind::kSubC:
1086   case TensorExp::Kind::kSubI:
1087   case TensorExp::Kind::kOrI:
1088   case TensorExp::Kind::kXorI:
1089     // An additive operation needs to be performed
1090     // for the disjunction of sparse iteration spaces.
1091     //
1092     //  x+y|!y | y |    x-y|!y | y |
1093     //  ---+---+---+    ---+---+---+
1094     //  !x | 0 | y |    !x | 0 |-y |
1095     //   x | x |x+y|     x | x |x-y|
1096     {
1097       const ExprId e0 = expr.children.e0;
1098       const ExprId e1 = expr.children.e1;
1099       return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1100     }
1101   case TensorExp::Kind::kCmpF:
1102   case TensorExp::Kind::kCmpI:
1103     // A comparison operation needs to be performed
1104     // for the disjunction of sparse iteration spaces.
1105     //
1106     //   x < y |  !y   |   y   |
1107     //  -------+-------+-------+
1108     //     !x  |   0   | 0 < y |
1109     //      x  | x < 0 | x < y |
1110     {
1111       const ExprId e0 = expr.children.e0;
1112       const ExprId e1 = expr.children.e1;
1113       return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1114     }
1115   case TensorExp::Kind::kShrS:
1116   case TensorExp::Kind::kShrU:
1117   case TensorExp::Kind::kShlI:
1118     // A shift operation by an invariant amount (viz. tensor expressions
1119     // can only occur on the left-hand side of the operator) can be handled
1120     // with the conjunction rule.
1121     {
1122       const ExprId e0 = expr.children.e0;
1123       const ExprId e1 = expr.children.e1;
1124       assert(isInvariant(e1));
1125       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1126     }
1127   case TensorExp::Kind::kBinary:
1128     // A custom binary operation.
1129     //
1130     //  x op y|   !y    |       y      |
1131     //  ------+---------+--------------+
1132     //    !x  |  empty  |   right(y)   |
1133     //     x  | left(x) | overlap(x,y) |
1134     {
1135       const ExprId e0 = expr.children.e0;
1136       const ExprId e1 = expr.children.e1;
1137       BinaryOp binop = cast<BinaryOp>(expr.op);
1138       const LatSetId child0 = buildLattices(e0, i);
1139       const LatSetId child1 = buildLattices(e1, i);
1140       Region &leftRegion = binop.getLeftRegion();
1141       Region &rightRegion = binop.getRightRegion();
1142       // Left Region.
1143       Operation *leftYield = nullptr;
1144       if (!leftRegion.empty()) {
1145         Block &leftBlock = leftRegion.front();
1146         leftYield = leftBlock.getTerminator();
1147       }
1148       // Right Region.
1149       Operation *rightYield = nullptr;
1150       if (!rightRegion.empty()) {
1151         Block &rightBlock = rightRegion.front();
1152         rightYield = rightBlock.getTerminator();
1153       }
1154       bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1155       bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1156       return combiSet(e, child0, child1, binop, includeLeft,
1157                       TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1158                       TensorExp::Kind::kBinaryBranch, rightYield);
1159     }
1160   case TensorExp::Kind::kReduce:
1161     // A custom reduce operation.
1162     {
1163       const ExprId e0 = expr.children.e0;
1164       const ExprId e1 = expr.children.e1;
1165       Operation *const op = expr.op;
1166       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1167     }
1168   case TensorExp::Kind::kDenseOp: {
1169     // It does not really matter whether we use a conjunctive or disjunctive set
1170     // here: since all the operands of kDenseOp must be dense, a disjunctive set
1171     // will eventually be optimized into a conjunctive set anyway.
1172     if (expr.children.e1 == detail::kInvalidId) {
1173       const ExprId e0 = expr.children.e0;
1174       Operation *const op = expr.op;
1175       return mapSet(kind, buildLattices(e0, i), Value(), op);
1176     }
1177 
1178     const ExprId e0 = expr.children.e0;
1179     const ExprId e1 = expr.children.e1;
1180     Operation *const op = expr.op;
1181     return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1182   }
1183   }
1184   llvm_unreachable("unexpected expression kind");
1185 }
1186 
1187 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1188   // Build the linalg semantics backward from yield.
1189   Operation *yield = op.getRegion().front().getTerminator();
1190   assert(isa<linalg::YieldOp>(yield));
1191   return buildTensorExp(op, yield->getOperand(0)).first;
1192 }
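
// For illustration, a linalg.generic whose body is
//   %0 = arith.mulf %a, %b : f64
//   linalg.yield %0 : f64
// is traversed backward from the yield operand by buildTensorExp below,
// producing the expression kMulF(kTensor(a), kTensor(b)).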
1193 
1194 /// Only returns false if we are certain this is a nonzero.
1195 bool Merger::maybeZero(ExprId e) const {
1196   const auto &expr = exp(e);
1197   if (expr.kind == TensorExp::Kind::kInvariant) {
1198     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1199       ArrayAttr arrayAttr = c.getValue();
1200       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1201              cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1202     }
1203     if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1204       return c.value() == 0;
1205     if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1206       return c.value().isZero();
1207   }
1208   return true;
1209 }
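
// For example, a constant operand 3.0 lets us rule out zero (returns false),
// whereas a constant 0.0, or any value that is not a recognized constant,
// conservatively returns true.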
1210 
1211 Type Merger::inferType(ExprId e, Value src) const {
1212   // Obtain the destination type from the cast node.
1213   Type dtp = exp(e).val.getType();
1214   // Inspect source type. For vector types, apply the same
1215   // vectorization to the destination type.
1216   if (auto vtp = dyn_cast<VectorType>(src.getType()))
1217     return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1218   return dtp;
1219 }
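
// For example, if the source is a vector<16xf32> and the cast's destination
// scalar type is f64, the inferred type is vector<16xf64>, preserving the
// element count and scalable dimensions.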
1220 
1221 /// Ensures that the sparsifier can generate code for the expression.
1222 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1223   // Arguments are always admissible.
1224   if (isa<BlockArgument>(v))
1225     return true;
1226   // Accept index anywhere.
1227   Operation *def = v.getDefiningOp();
1228   if (isa<linalg::IndexOp>(def))
1229     return true;
1230   // Operation defined outside branch.
1231   if (def->getBlock() != block)
1232     return def->getBlock() != op->getBlock(); // invariant?
1233   // Operation defined within branch. Anything is accepted,
1234   // as long as all subexpressions are admissible.
1235   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1236     if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1237       return false;
1238   return true;
1239 }
1240 
1241 /// Ensures that the sparsifier can generate code for the branch.
1242 static bool isAdmissibleBranch(Operation *op, Region &region) {
1243   if (region.empty())
1244     return true;
1245   // Build the semi-ring branch semantics backward from yield.
1246   Operation *yield = region.front().getTerminator();
1247   assert(isa<YieldOp>(yield));
1248   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1249 }
1250 
1251 std::pair<std::optional<ExprId>, bool>
1252 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1253   // Recursion leaves.
1254   if (auto arg = dyn_cast<BlockArgument>(v)) {
1255     const TensorId tid = makeTensorId(arg.getArgNumber());
1256     // Any argument of the generic op that is not marked as a scalar
1257     // argument is considered a tensor, indexed by the implicit loop
1258     // bounds. This includes rank-0 tensor arguments.
1259     if (arg.getOwner()->getParentOp() == op) {
1260       OpOperand &t = op->getOpOperand(tid);
1261       bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1262       if (!op.isScalar(&t))
1263         return {addTensorExp(tid), hasSpDep};
1264       v = t.get(); // get scalar value
1265     }
1266     // Any other argument (marked as scalar argument for the generic op
1267     // or belonging to an enveloping op) is considered invariant.
1268     return {addInvariantExp(v), /*hasSpDep=*/false};
1269   }
1270   // Something defined outside is invariant.
1271   Operation *def = v.getDefiningOp();
1272   if (def->getBlock() != &op.getRegion().front())
1273     return {addInvariantExp(v), /*hasSpDep=*/false};
1274   // Construct index operations.
1275   if (def->getNumOperands() == 0) {
1276     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1277       return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1278   }
1279 
1280   // Construct unary operations if subexpression can be built.
1281   if (def->getNumOperands() == 1) {
1282     const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1283     if (x.has_value()) {
1284       const ExprId e = *x;
1285       if (isa<math::AbsFOp>(def))
1286         return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1287       if (isa<complex::AbsOp>(def))
1288         return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1289       if (isa<math::AbsIOp>(def))
1290         return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1291       if (isa<math::CeilOp>(def))
1292         return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1293       if (isa<math::FloorOp>(def))
1294         return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1295       if (isa<math::SqrtOp>(def))
1296         return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1297       if (isa<complex::SqrtOp>(def))
1298         return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1299       if (isa<math::ExpM1Op>(def))
1300         return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1301       if (isa<complex::Expm1Op>(def))
1302         return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1303       if (isa<math::Log1pOp>(def))
1304         return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1305       if (isa<complex::Log1pOp>(def))
1306         return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1307       if (isa<math::SinOp>(def))
1308         return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1309       if (isa<complex::SinOp>(def))
1310         return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1311       if (isa<math::TanhOp>(def))
1312         return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1313       if (isa<complex::TanhOp>(def))
1314         return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1315       if (isa<arith::NegFOp>(def))
1316         return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1317       if (isa<complex::NegOp>(def))
1318         return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1319       if (isa<arith::TruncFOp>(def))
1320         return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1321       if (isa<arith::ExtFOp>(def))
1322         return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1323       if (isa<arith::FPToSIOp>(def))
1324         return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1325       if (isa<arith::FPToUIOp>(def))
1326         return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1327       if (isa<arith::SIToFPOp>(def))
1328         return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1329       if (isa<arith::UIToFPOp>(def))
1330         return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1331       if (isa<arith::ExtSIOp>(def))
1332         return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1333       if (isa<arith::ExtUIOp>(def))
1334         return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1335       if (isa<arith::IndexCastOp>(def))
1336         return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1337       if (isa<arith::TruncIOp>(def))
1338         return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1339       if (isa<complex::ImOp>(def))
1340         return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1341       if (isa<complex::ReOp>(def))
1342         return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1343       if (isa<arith::BitcastOp>(def))
1344         return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1345       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1346         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1347             isAdmissibleBranch(unop, unop.getAbsentRegion()))
1348           return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1349       }
1350       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1351         if (isAdmissibleBranch(selop, selop.getRegion()))
1352           return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1353       }
1354     }
1355   }
1356   // Construct binary operations if subexpressions can be built.
1357   // See buildLattices() for an explanation of rejecting certain
1358   // division and shift operations.
1359   if (def->getNumOperands() == 2) {
1360     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1361     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1362     bool hasSpDep = xDepSp || yDepSp;
1363     if (x.has_value() && y.has_value()) {
1364       const ExprId e0 = *x;
1365       const ExprId e1 = *y;
1366       if (isa<arith::MulFOp>(def))
1367         return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
1368       if (isa<complex::MulOp>(def))
1369         return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
1370       if (isa<arith::MulIOp>(def))
1371         return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
1372       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1373         return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
1374       if (isa<complex::DivOp>(def) && !maybeZero(e1))
1375         return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
1376       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1377         return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
1378       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1379         return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
1380       if (isa<arith::AddFOp>(def))
1381         return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
1382       if (isa<complex::AddOp>(def))
1383         return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
1384       if (isa<arith::AddIOp>(def))
1385         return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
1386       if (isa<arith::SubFOp>(def))
1387         return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
1388       if (isa<complex::SubOp>(def))
1389         return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
1390       if (isa<arith::SubIOp>(def))
1391         return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
1392       if (isa<arith::AndIOp>(def))
1393         return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
1394       if (isa<arith::OrIOp>(def))
1395         return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
1396       if (isa<arith::XOrIOp>(def))
1397         return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
1398       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1399         return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
1400       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1401         return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
1402       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1403         return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
1404       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1405         if (ci.getPredicate() == arith::CmpIPredicate::eq ||
1406             ci.getPredicate() == arith::CmpIPredicate::sle ||
1407             ci.getPredicate() == arith::CmpIPredicate::sge ||
1408             ci.getPredicate() == arith::CmpIPredicate::ule ||
1409             ci.getPredicate() == arith::CmpIPredicate::uge) {
1410           // We cannot sparsify comparisons whose predicate involves equality,
1411           // because 0 <= 0 yields true and thus densifies the result.
1412           return {std::nullopt, false};
1413         }
1414 
1415         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1416                         ci.getPredicateAttr());
1417         return {e, hasSpDep};
1418       }
1419       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1420         if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
1421             cf.getPredicate() == arith::CmpFPredicate::OGE ||
1422             cf.getPredicate() == arith::CmpFPredicate::OLE ||
1423             cf.getPredicate() == arith::CmpFPredicate::ONE ||
1424             cf.getPredicate() == arith::CmpFPredicate::UEQ ||
1425             cf.getPredicate() == arith::CmpFPredicate::UGE ||
1426             cf.getPredicate() == arith::CmpFPredicate::ULE ||
1427             cf.getPredicate() == arith::CmpFPredicate::ORD ||
1428             cf.getPredicate() == arith::CmpFPredicate::UNO) {
1429           // We cannot sparsify these comparison predicates, because 0 <= 0
1430           // yields true and thus densifies the result.
1431           return {std::nullopt, false};
1432         }
1433         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1434                         cf.getPredicateAttr());
1435         return {e, hasSpDep};
1436       }
1437       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1438         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1439             (binop.getLeftIdentity() ||
1440              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1441             (binop.getRightIdentity() ||
1442              isAdmissibleBranch(binop, binop.getRightRegion())))
1443           return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
1444       }
1445     }
1446   }
1447   // Construct ternary operations if subexpressions can be built.
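  // (For example, sparse_tensor.reduce takes two values plus an identity,
  // which is why it surfaces here as a three-operand operation.)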
1448   if (def->getNumOperands() == 3) {
1449     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1450     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1451     const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1452     bool hasSpDep = xDepSp || yDepSp || zDepSp;
1453     if (x.has_value() && y.has_value() && z.has_value()) {
1454       const ExprId e0 = *x;
1455       const ExprId e1 = *y;
1456       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1457         if (isAdmissibleBranch(redop, redop.getRegion()))
1458           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1459       }
1460     }
1461   }
1462 
1463   // If we reach here, we are dealing with an operation that is not currently
1464   // sparsifiable. We can still generate code for it if all its operands only
1465   // have dense dependencies (i.e., all the values are loaded from dense
1466   // tensors).
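  // As an illustrative sketch (not a normative list): an op such as
  // math.atan2 whose operands are all loaded from dense tensors has no
  // dedicated TensorExp kind above, yet it can still be wrapped as a
  // kDenseOp below and simply cloned back during code generation.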
1467   if (def->getNumResults() != 1) // only handle single-result operations.
1468     return {std::nullopt, false};
1469 
1470   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1471   // Build all the subexpressions.
1472   for (Value operand : def->getOperands())
1473     subExp.push_back(buildTensorExp(op, operand));
1474 
1475   if (llvm::all_of(subExp,
1476                    [](auto e) { return e.first.has_value() && !e.second; })) {
1477     // All the subexpressions can be built and have *no* sparse dependencies.
1478     if (subExp.size() == 2) {
1479       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1480                       *subExp[1].first, def);
1481       return {e, false};
1482     }
1483     if (subExp.size() == 1) {
1484       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1485                       detail::kInvalidId, def);
1486       return {e, false};
1487     }
1488   }
1489   // Cannot build.
1490   return {std::nullopt, false};
1491 }
1492 
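// The helper below "invokes" a semi-ring region by cloning its block in place
// of a call. A rough sketch of the rewrite (names are illustrative only):
//
//   region argument block:               inlined at the insertion point:
//     ^bb0(%x: f64, %y: f64):              %m = arith.mulf %v0, %v1 : f64
//       %m = arith.mulf %x, %y : f64       (insertYieldOp returns %m)
//       sparse_tensor.yield %m : f64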
1493 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1494                            ValueRange vals) {
1495   // Make a clone of the region.
1496   Region tmpRegion;
1497   IRMapping mapper;
1498   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1499   Block &clonedBlock = tmpRegion.front();
1500   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1501   // Merge cloned block and return yield value.
1502   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1503   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1504   Value val = clonedYield.getResult();
1505   rewriter.eraseOp(clonedYield);
1506   rewriter.eraseOp(placeholder);
1507   return val;
1508 }
1509 
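// For context, a sparse_tensor.unary op roughly takes the form below
// (schematic only; values, types, and region bodies are illustrative):
//
//   %r = sparse_tensor.unary %a : f64 to f64
//          present={
//            ^bb0(%x: f64):
//              %1 = arith.negf %x : f64
//              sparse_tensor.yield %1 : f64
//          }
//          absent={}
//
// The helper below materializes only the "present" region.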
1510 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1511                                Operation *op, Value v0) {
1512   if (!v0)
1513     // Empty input value must be propagated.
1514     return Value();
1515   UnaryOp unop = cast<UnaryOp>(op);
1516   Region &presentRegion = unop.getPresentRegion();
1517   if (presentRegion.empty())
1518     // Uninitialized Value() will be interpreted as missing data in the
1519     // output.
1520     return Value();
1521   return insertYieldOp(rewriter, loc, presentRegion, {v0});
1522 }
1523 
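// Similarly, a sparse_tensor.binary op roughly takes the form below
// (schematic only; values, types, and region bodies are illustrative):
//
//   %r = sparse_tensor.binary %a, %b : f64, f64 to f64
//          overlap={
//            ^bb0(%x: f64, %y: f64):
//              %1 = arith.mulf %x, %y : f64
//              sparse_tensor.yield %1 : f64
//          }
//          left=identity
//          right={}
//
// The helper below materializes only the "overlap" region.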
1524 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1525                                 Operation *op, Value v0, Value v1) {
1526   if (!v0 || !v1)
1527     // Empty input values must be propagated.
1528     return Value();
1529   BinaryOp binop = cast<BinaryOp>(op);
1530   Region &overlapRegion = binop.getOverlapRegion();
1531   if (overlapRegion.empty())
1532     // Uninitialized Value() will be interpreted as missing data in the
1533     // output.
1534     return Value();
1535   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1536 }
1537 
1538 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1539                        Value v1) const {
1540   const auto &expr = exp(e);
1541   switch (expr.kind) {
1542   // Leaf.
1543   case TensorExp::Kind::kTensor:
1544   case TensorExp::Kind::kInvariant:
1545   case TensorExp::Kind::kLoopVar:
1546   case TensorExp::Kind::kSynZero:
1547     llvm_unreachable("unexpected non-op");
1548   // Unary operations.
1549   case TensorExp::Kind::kAbsF:
1550     return rewriter.create<math::AbsFOp>(loc, v0);
1551   case TensorExp::Kind::kAbsC: {
1552     auto type = cast<ComplexType>(v0.getType());
1553     auto eltType = cast<FloatType>(type.getElementType());
1554     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1555   }
1556   case TensorExp::Kind::kAbsI:
1557     return rewriter.create<math::AbsIOp>(loc, v0);
1558   case TensorExp::Kind::kCeilF:
1559     return rewriter.create<math::CeilOp>(loc, v0);
1560   case TensorExp::Kind::kFloorF:
1561     return rewriter.create<math::FloorOp>(loc, v0);
1562   case TensorExp::Kind::kSqrtF:
1563     return rewriter.create<math::SqrtOp>(loc, v0);
1564   case TensorExp::Kind::kSqrtC:
1565     return rewriter.create<complex::SqrtOp>(loc, v0);
1566   case TensorExp::Kind::kExpm1F:
1567     return rewriter.create<math::ExpM1Op>(loc, v0);
1568   case TensorExp::Kind::kExpm1C:
1569     return rewriter.create<complex::Expm1Op>(loc, v0);
1570   case TensorExp::Kind::kLog1pF:
1571     return rewriter.create<math::Log1pOp>(loc, v0);
1572   case TensorExp::Kind::kLog1pC:
1573     return rewriter.create<complex::Log1pOp>(loc, v0);
1574   case TensorExp::Kind::kSinF:
1575     return rewriter.create<math::SinOp>(loc, v0);
1576   case TensorExp::Kind::kSinC:
1577     return rewriter.create<complex::SinOp>(loc, v0);
1578   case TensorExp::Kind::kTanhF:
1579     return rewriter.create<math::TanhOp>(loc, v0);
1580   case TensorExp::Kind::kTanhC:
1581     return rewriter.create<complex::TanhOp>(loc, v0);
1582   case TensorExp::Kind::kNegF:
1583     return rewriter.create<arith::NegFOp>(loc, v0);
1584   case TensorExp::Kind::kNegC:
1585     return rewriter.create<complex::NegOp>(loc, v0);
1586   case TensorExp::Kind::kNegI: // no dedicated integer negation in arith; emit 0 - v0
1587     return rewriter.create<arith::SubIOp>(
1588         loc,
1589         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1590                                            rewriter.getZeroAttr(v0.getType())),
1591         v0);
1592   case TensorExp::Kind::kTruncF:
1593     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1594   case TensorExp::Kind::kExtF:
1595     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1596   case TensorExp::Kind::kCastFS:
1597     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1598   case TensorExp::Kind::kCastFU:
1599     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1600   case TensorExp::Kind::kCastSF:
1601     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1602   case TensorExp::Kind::kCastUF:
1603     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1604   case TensorExp::Kind::kCastS:
1605     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1606   case TensorExp::Kind::kCastU:
1607     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1608   case TensorExp::Kind::kCastIdx:
1609     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1610   case TensorExp::Kind::kTruncI:
1611     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1612   case TensorExp::Kind::kCIm: {
1613     auto type = cast<ComplexType>(v0.getType());
1614     auto eltType = cast<FloatType>(type.getElementType());
1615     return rewriter.create<complex::ImOp>(loc, eltType, v0);
1616   }
1617   case TensorExp::Kind::kCRe: {
1618     auto type = cast<ComplexType>(v0.getType());
1619     auto eltType = cast<FloatType>(type.getElementType());
1620     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1621   }
1622   case TensorExp::Kind::kBitCast:
1623     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1624   // Binary operations.
1625   case TensorExp::Kind::kMulF:
1626     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1627   case TensorExp::Kind::kMulC:
1628     return rewriter.create<complex::MulOp>(loc, v0, v1);
1629   case TensorExp::Kind::kMulI:
1630     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1631   case TensorExp::Kind::kDivF:
1632     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1633   case TensorExp::Kind::kDivC:
1634     return rewriter.create<complex::DivOp>(loc, v0, v1);
1635   case TensorExp::Kind::kDivS:
1636     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1637   case TensorExp::Kind::kDivU:
1638     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1639   case TensorExp::Kind::kAddF:
1640     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1641   case TensorExp::Kind::kAddC:
1642     return rewriter.create<complex::AddOp>(loc, v0, v1);
1643   case TensorExp::Kind::kAddI:
1644     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1645   case TensorExp::Kind::kSubF:
1646     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1647   case TensorExp::Kind::kSubC:
1648     return rewriter.create<complex::SubOp>(loc, v0, v1);
1649   case TensorExp::Kind::kSubI:
1650     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1651   case TensorExp::Kind::kAndI:
1652     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1653   case TensorExp::Kind::kOrI:
1654     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1655   case TensorExp::Kind::kXorI:
1656     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1657   case TensorExp::Kind::kShrS:
1658     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1659   case TensorExp::Kind::kShrU:
1660     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1661   case TensorExp::Kind::kShlI:
1662     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1663   case TensorExp::Kind::kCmpI: {
1664     auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1665     return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1666   }
1667   case TensorExp::Kind::kCmpF: {
1668     auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1669     return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1670   }
1671   case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1672     return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1673                          {v0});
1674   case TensorExp::Kind::kUnary:
1675     return buildUnaryPresent(rewriter, loc, expr.op, v0);
1676   case TensorExp::Kind::kSelect:
1677     return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1678                          {v0});
1679   case TensorExp::Kind::kBinary:
1680     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1681   case TensorExp::Kind::kReduce: {
1682     ReduceOp redOp = cast<ReduceOp>(expr.op);
1683     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1684   }
1685   case TensorExp::Kind::kDenseOp: {
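    // Clone the original dense operation, substituting its operands with the
    // values (v0 and, when present, v1) computed for the current loop nest.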
1686     Operation *actualOp = expr.op;
1687     IRMapping mapping;
1688     mapping.map(actualOp->getOperand(0), v0);
1689     if (actualOp->getNumOperands() == 2)
1690       mapping.map(actualOp->getOperand(1), v1);
1691     return rewriter.clone(*actualOp, mapping)->getResult(0);
1692   }
1693   }
1694   llvm_unreachable("unexpected expression kind in build");
1695 }
1696 
1697 } // namespace sparse_tensor
1698 } // namespace mlir
1699