1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace sparse_tensor {
21 
22 enum class ExpArity {
23   kNullary,
24   kUnary,
25   kBinary,
26 };
27 
28 static ExpArity getExpArity(TensorExp::Kind k) {
29   switch (k) {
30   // Leaf.
31   case TensorExp::Kind::kTensor:
32   case TensorExp::Kind::kInvariant:
33   case TensorExp::Kind::kLoopVar:
34   case TensorExp::Kind::kSynZero:
35     return ExpArity::kNullary;
36   case TensorExp::Kind::kAbsF:
37   case TensorExp::Kind::kAbsC:
38   case TensorExp::Kind::kAbsI:
39   case TensorExp::Kind::kCeilF:
40   case TensorExp::Kind::kFloorF:
41   case TensorExp::Kind::kSqrtF:
42   case TensorExp::Kind::kSqrtC:
43   case TensorExp::Kind::kExpm1F:
44   case TensorExp::Kind::kExpm1C:
45   case TensorExp::Kind::kLog1pF:
46   case TensorExp::Kind::kLog1pC:
47   case TensorExp::Kind::kSinF:
48   case TensorExp::Kind::kSinC:
49   case TensorExp::Kind::kTanhF:
50   case TensorExp::Kind::kTanhC:
51   case TensorExp::Kind::kTruncF:
52   case TensorExp::Kind::kExtF:
53   case TensorExp::Kind::kCastFS:
54   case TensorExp::Kind::kCastFU:
55   case TensorExp::Kind::kCastSF:
56   case TensorExp::Kind::kCastUF:
57   case TensorExp::Kind::kCastS:
58   case TensorExp::Kind::kCastU:
59   case TensorExp::Kind::kCastIdx:
60   case TensorExp::Kind::kTruncI:
61   case TensorExp::Kind::kCIm:
62   case TensorExp::Kind::kCRe:
63   case TensorExp::Kind::kBitCast:
64   case TensorExp::Kind::kBinaryBranch:
65   case TensorExp::Kind::kUnary:
66   case TensorExp::Kind::kSelect:
67   case TensorExp::Kind::kNegF:
68   case TensorExp::Kind::kNegC:
69   case TensorExp::Kind::kNegI:
70     return ExpArity::kUnary;
71   // Binary operations.
72   case TensorExp::Kind::kDivF:
73   case TensorExp::Kind::kDivC:
74   case TensorExp::Kind::kDivS:
75   case TensorExp::Kind::kDivU:
76   case TensorExp::Kind::kShrS:
77   case TensorExp::Kind::kShrU:
78   case TensorExp::Kind::kShlI:
79   case TensorExp::Kind::kMulF:
80   case TensorExp::Kind::kMulC:
81   case TensorExp::Kind::kMulI:
82   case TensorExp::Kind::kAndI:
83   case TensorExp::Kind::kAddF:
84   case TensorExp::Kind::kAddC:
85   case TensorExp::Kind::kAddI:
86   case TensorExp::Kind::kOrI:
87   case TensorExp::Kind::kXorI:
88   case TensorExp::Kind::kBinary:
89   case TensorExp::Kind::kReduce:
90   case TensorExp::Kind::kSubF:
91   case TensorExp::Kind::kSubC:
92   case TensorExp::Kind::kSubI:
93   case TensorExp::Kind::kCmpF:
94   case TensorExp::Kind::kCmpI:
95   case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
96     return ExpArity::kBinary;
97   }
98   llvm_unreachable("unexpected kind");
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Constructors.
103 //===----------------------------------------------------------------------===//
104 
105 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
106                      Operation *o, Attribute a)
107     : kind(k), val(v), op(o) {
108   switch (kind) {
109   // Leaf.
110   case TensorExp::Kind::kTensor:
111     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
112     tensor = x;
113     return;
114   case TensorExp::Kind::kSynZero:
115     assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
116     return;
117   case TensorExp::Kind::kInvariant:
118     assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
119     return;
120   case TensorExp::Kind::kLoopVar:
121     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
122     loop = x;
123     return;
124   // Unary operations.
125   case TensorExp::Kind::kAbsF:
126   case TensorExp::Kind::kAbsC:
127   case TensorExp::Kind::kAbsI:
128   case TensorExp::Kind::kCeilF:
129   case TensorExp::Kind::kFloorF:
130   case TensorExp::Kind::kSqrtF:
131   case TensorExp::Kind::kSqrtC:
132   case TensorExp::Kind::kExpm1F:
133   case TensorExp::Kind::kExpm1C:
134   case TensorExp::Kind::kLog1pF:
135   case TensorExp::Kind::kLog1pC:
136   case TensorExp::Kind::kSinF:
137   case TensorExp::Kind::kSinC:
138   case TensorExp::Kind::kTanhF:
139   case TensorExp::Kind::kTanhC:
140   case TensorExp::Kind::kNegF:
141   case TensorExp::Kind::kNegC:
142   case TensorExp::Kind::kNegI:
143   case TensorExp::Kind::kCIm:
144   case TensorExp::Kind::kCRe:
145     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
146     children.e0 = x;
147     children.e1 = y;
148     return;
149   case TensorExp::Kind::kTruncF:
150   case TensorExp::Kind::kExtF:
151   case TensorExp::Kind::kCastFS:
152   case TensorExp::Kind::kCastFU:
153   case TensorExp::Kind::kCastSF:
154   case TensorExp::Kind::kCastUF:
155   case TensorExp::Kind::kCastS:
156   case TensorExp::Kind::kCastU:
157   case TensorExp::Kind::kCastIdx:
158   case TensorExp::Kind::kTruncI:
159   case TensorExp::Kind::kBitCast:
160     assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
161     children.e0 = x;
162     children.e1 = y;
163     return;
164   case TensorExp::Kind::kBinaryBranch:
165   case TensorExp::Kind::kSelect:
166     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
167     children.e0 = x;
168     children.e1 = y;
169     return;
170   case TensorExp::Kind::kUnary:
171     // No assertion on y can be made, as the branching paths involve both
172     // a unary (`mapSet`) and binary (`disjSet`) pathway.
173     assert(x != detail::kInvalidId && !v && o);
174     children.e0 = x;
175     children.e1 = y;
176     return;
177   // Binary operations.
178   case TensorExp::Kind::kMulF:
179   case TensorExp::Kind::kMulC:
180   case TensorExp::Kind::kMulI:
181   case TensorExp::Kind::kDivF:
182   case TensorExp::Kind::kDivC:
183   case TensorExp::Kind::kDivS:
184   case TensorExp::Kind::kDivU:
185   case TensorExp::Kind::kAddF:
186   case TensorExp::Kind::kAddC:
187   case TensorExp::Kind::kAddI:
188   case TensorExp::Kind::kSubF:
189   case TensorExp::Kind::kSubC:
190   case TensorExp::Kind::kSubI:
191   case TensorExp::Kind::kAndI:
192   case TensorExp::Kind::kOrI:
193   case TensorExp::Kind::kXorI:
194   case TensorExp::Kind::kShrS:
195   case TensorExp::Kind::kShrU:
196   case TensorExp::Kind::kShlI:
197     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
198     children.e0 = x;
199     children.e1 = y;
200     return;
201   case TensorExp::Kind::kCmpF:
202   case TensorExp::Kind::kCmpI:
203     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
204     attr = a;
205     children.e0 = x;
206     children.e1 = y;
207     return;
208   case TensorExp::Kind::kBinary:
209   case TensorExp::Kind::kReduce:
210     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
211     children.e0 = x;
212     children.e1 = y;
213     return;
214   case TensorExp::Kind::kDenseOp:
215     assert(x != detail::kInvalidId && !v && o);
216     children.e0 = x;
217     children.e1 = y;
218     return;
219   }
220   llvm_unreachable("unexpected kind");
221 }
222 
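// Tensor identifiers: the operands of the linalg operation occupy the ids
// 0 .. numInputOutputTensors-1, with the output tensor last; one extra
// synthetic tensor (id numInputOutputTensors) is appended to represent the
// universal iteration space. For example (hypothetical instantiation), a
// kernel with two inputs and one output gives
//
//   Merger merger(/*numInputOutputTensors=*/3, /*numNativeLoops=*/2,
//                 /*numFilterLoops=*/0, /*maxLvlRank=*/2);
//
// so the output tensor gets id 2, the synthetic tensor id 3, and numTensors
// is 4.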
223 Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
224                unsigned numFilterLoops, unsigned maxLvlRank)
225     : outTensor(numInputOutputTensors - 1),
226       syntheticTensor(numInputOutputTensors),
227       numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
228       numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
229       lvlTypes(numTensors,
230                std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
231       loopToLvl(numTensors,
232                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
233       lvlToLoop(numTensors,
234                 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
235       loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlDLTPair>>(
236                                          numTensors, std::nullopt)),
237       levelToDependentLoop(numTensors,
238                            std::vector<std::vector<LoopCoeffPair>>(
239                                maxLvlRank, std::vector<LoopCoeffPair>())),
240       loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
241 
242 //===----------------------------------------------------------------------===//
243 // Lattice methods.
244 //===----------------------------------------------------------------------===//
245 
246 ExprId Merger::addTensorExp(TensorId t) {
247   assert(isValidTensorId(t));
248   const ExprId eNew(tensorExps.size());
249   tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
250                           Value(), nullptr, nullptr);
251   return eNew;
252 }
253 
254 ExprId Merger::addLoopVarExp(LoopId i) {
255   assert(isValidLoopId(i));
256   const ExprId eNew(tensorExps.size());
257   tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
258                           Value(), nullptr, nullptr);
259   return eNew;
260 }
261 
262 ExprId Merger::addInvariantExp(Value v) {
263   const ExprId eNew(tensorExps.size());
264   tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
265                           detail::kInvalidId, v, nullptr, nullptr);
266   return eNew;
267 }
268 
269 ExprId Merger::addSynZeroExp() {
270   const ExprId eNew(tensorExps.size());
271   tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
272                           detail::kInvalidId, Value(), nullptr, nullptr);
273   return eNew;
274 }
275 
276 ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
277                       Attribute attr) {
278   assert(k > TensorExp::Kind::kLoopVar);
279   const ExprId eNew(tensorExps.size());
280   tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
281   return eNew;
282 }
283 
284 ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
285                       Attribute attr) {
286   assert(k > TensorExp::Kind::kLoopVar);
287   const ExprId eNew(tensorExps.size());
288   tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
289   return eNew;
290 }
291 
292 LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
293   const LatPointId pNew(latPoints.size());
294   const unsigned size = numLoops * numTensors;
295   const TensorLoopId b = makeTensorLoopId(t, i);
296   latPoints.emplace_back(size, e);
297   latPoints[pNew].bits.set(b);
298   return pNew;
299 }
300 
301 LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
302   assert(bits.size() == numLoops * numTensors);
303   const LatPointId pNew(latPoints.size());
304   latPoints.emplace_back(bits, e);
305   return pNew;
306 }
307 
308 LatSetId Merger::addSet() {
309   const LatSetId sNew(latSets.size());
310   latSets.emplace_back();
311   return sNew;
312 }
313 
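// Conjunction of two lattice points: the loop-condition bits are merged
// (set union) and the expressions are combined under the operator of `e`.
// For example (sketch), if p0 carries bits {a} with expression `a` and p1
// carries bits {b} with expression `b`, then for a multiplication the new
// point carries bits {a,b} with expression `a * b`.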
314 LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
315                            Operation *op) {
316   TensorExp::Kind kind = exp(e).kind;
317   Attribute attr = exp(e).attr;
318   const LatPointId pNew(latPoints.size());
319   const auto &point0 = lat(p0);
320   const auto &point1 = lat(p1);
321   BitVector bits(point0.bits);
322   bits |= point1.bits;
323   const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
324   latPoints.emplace_back(bits, ne);
325   return pNew;
326 }
327 
328 LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
329   const LatSetId sNew = addSet();
330   auto &setNew = latSets[sNew];
331   for (const LatPointId p0 : set(s0))
332     for (const LatPointId p1 : set(s1))
333       setNew.push_back(conjLat(e, p0, p1, op));
334   return sNew;
335 }
336 
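// Disjunction of two lattice sets: start from the conjunction of both sets,
// then append all points of s0 and all points of s1, reflecting the three
// regions of the disjunction. For example (sketch), for x(i) + y(i) with
// singleton operand sets {x} and {y}, the result is
//   { {x,y} : x+y, {x} : x, {y} : y },
// and for a subtraction the y-only region is first mapped to the unary
// negation -y.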
337 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
338   const LatSetId sNew = conjSet(e, s0, s1, op);
339   TensorExp::Kind kind = exp(e).kind;
340 
341   // Followed by all in s0.
342   latSets[sNew].append(latSets[s0]);
343   // Map binary 0-y to unary -y.
344   // TODO: move this if-else logic into buildLattices
345   if (kind == TensorExp::Kind::kSubF)
346     s1 = mapSet(TensorExp::Kind::kNegF, s1);
347   else if (kind == TensorExp::Kind::kSubC)
348     s1 = mapSet(TensorExp::Kind::kNegC, s1);
349   else if (kind == TensorExp::Kind::kSubI)
350     s1 = mapSet(TensorExp::Kind::kNegI, s1);
351   // Followed by all in s1.
352   latSets[sNew].append(latSets[s1]);
353   return sNew;
354 }
355 
356 LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
357   assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358          exp(e).kind == TensorExp::Kind::kCmpF);
359   const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360 
361   ExprId e0 = exp(e).children.e0;
362   ExprId e1 = exp(e).children.e1;
363   if (exp(e0).kind == TensorExp::Kind::kSynZero ||
364       exp(e1).kind == TensorExp::Kind::kSynZero) {
365     // lhs and rhs can't be synthetic zero at the same time.
366     assert(exp(e0).kind != exp(e1).kind);
367     // If one of the operands has already been assigned to zero (the
368     // element is absent in the corresponding operand), then we do not
369     // need to build a disjunctive set for it.
370     return sNew;
371   }
372 
373   auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
374   auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
375   latSets[sNew].append(latSets[lhsSet]);
376   latSets[sNew].append(latSets[rhsSet]);
377   return sNew;
378 }
379 
380 LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381                           bool includeLeft, TensorExp::Kind ltrans,
382                           Operation *opleft, bool includeRight,
383                           TensorExp::Kind rtrans, Operation *opright) {
384   const LatSetId sNew = conjSet(e, s0, s1, orig);
385   // Left Region.
386   if (includeLeft) {
387     if (opleft)
388       s0 = mapSet(ltrans, s0, Value(), opleft);
389     latSets[sNew].append(latSets[s0]);
390   }
391   // Right Region.
392   if (includeRight) {
393     if (opright)
394       s1 = mapSet(rtrans, s1, Value(), opright);
395     latSets[sNew].append(latSets[s1]);
396   }
397   return sNew;
398 }
399 
400 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
401                         Operation *op) {
402   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
403          TensorExp::Kind::kDenseOp == kind);
404   const LatSetId sNew = addSet();
405   auto &setNew = latSets[sNew];
406   for (const LatPointId p : set(s0)) {
407     const auto &point = latPoints[p];
408     setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
409   }
410   return sNew;
411 }
412 
413 LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
414   TensorExp::Kind kind = exp(e).kind;
415   Attribute a = exp(e).attr;
416   assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
417   // Must be a binary operation.
418   const LatSetId sNew = addSet();
419   auto &setNew = latSets[sNew];
420   const ExprId zeroExp = addSynZeroExp();
421   for (const LatPointId p : set(s0)) {
422     const auto &point = latPoints[p];
423     ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
424                             : addExp(kind, point.exp, zeroExp, nullptr, a);
425     setNew.push_back(addLat(point.bits, newExp));
426   }
427   return sNew;
428 }
429 
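// Optimization drops lattice points that cannot yield a distinct loop body:
// points that are a straightforward copy of the output tensor, and points
// whose condition differs from an already accepted point in dense bits only
// (dense levels are always present, so they do not change the sparse
// iteration structure). For example (sketch), once a point with bits {a,b,d}
// has been accepted (d dense), a later point with bits {a,b} is dropped,
// since the two differ in the dense bit only.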
430 LatSetId Merger::optimizeSet(LatSetId s0) {
431   const LatSetId sNew = addSet();
432   auto &setNew = latSets[sNew];
433   const auto &set0 = set(s0);
434   assert(!set0.empty());
435   const LatPointId p0 = set0[0];
436   for (const LatPointId p1 : set0) {
437     bool add = true;
438     if (p0 != p1) {
439       // Check whether this is a straightforward copy.
440       if (expIsTensor(latPoints[p1].exp, outTensor))
441         continue;
442       // Check whether this conjunction is already covered.
443       for (const LatPointId p2 : setNew) {
444         assert(!latGT(p1, p2)); // Lj => Li would be bad
445         if (onlyDenseDiff(p2, p1)) {
446           add = false;
447           break;
448         }
449       }
450       assert(!add || latGT(p0, p1));
451     }
452     if (add)
453       setNew.push_back(p1);
454   }
455   for (const LatPointId p : setNew)
456     latPoints[p].simple = simplifyCond(sNew, p);
457   return sNew;
458 }
459 
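// Simplification of the loop conditions within one lattice point: conditions
// that are already implied by the sparse iteration itself are cleared from
// the `simple` bit vector. For example (sketch), for a singleton point with
// bits {a,b} where a is compressed and b is dense, only the bit for a needs
// to be tested at runtime; the dense bit for b is reset.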
460 BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
461   // First determine if this lattice point is a *singleton*, i.e.,
462   // the last point in a lattice, no other is less than this one.
463   bool isSingleton = true;
464   for (const LatPointId p1 : set(s0)) {
465     if (p0 != p1 && latGT(p0, p1)) {
466       isSingleton = false;
467       break;
468     }
469   }
470 
471   BitVector simple(latPoints[p0].bits);
472   bool reset = isSingleton && hasAnySparse(simple);
473   const TensorLoopId be = simple.size();
474   TensorLoopId offset = 0; // relative to the end
475   if (!reset)
476     // Start resetting from a dense level, so that the first bit (if kept)
477     // is not of an undefined level-type.
478     for (unsigned b = 0; b < be; b++) {
479       if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) {
480         offset = be - b - 1; // relative to the end
481         break;
482       }
483     }
484 
485   // Now apply the two basic rules, iterating over the bits in reverse order so
486   // that the rightmost bit (which could possibly be a synthetic tensor) is kept.
487   for (unsigned b = be - 1 - offset, i = 0; i < be;
488        b = b == 0 ? be - 1 : b - 1, i++) {
489     // A slice on a dense level has the `locate` property as well and can be optimized.
490     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
491       const auto dlt = getLvlType(b);
492       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
493           !isLooseCompressedDLT(dlt)) {
494         if (reset)
495           simple.reset(b);
496         reset = true;
497       }
498     }
499   }
500   return simple;
501 }
502 
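// Strict containment of loop conditions: lattice point i is greater than
// point j iff the bits of i form a proper superset of the bits of j.
// For example (sketch), {a,b} > {a} holds, whereas {a,b} > {b,c} and
// {a} > {a} do not.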
503 bool Merger::latGT(LatPointId i, LatPointId j) const {
504   const BitVector &bitsi = lat(i).bits;
505   const BitVector &bitsj = lat(j).bits;
506   assert(bitsi.size() == bitsj.size());
507   if (bitsi.count() > bitsj.count()) {
508     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509       if (bitsj[b] && !bitsi[b])
510         return false;
511     return true;
512   }
513   return false;
514 }
515 
516 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
517   BitVector tmp(latPoints[j].bits);
518   tmp ^= latPoints[i].bits;
519   return !hasAnySparse(tmp);
520 }
521 
522 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
523   const auto &expr = exp(e);
524   // First we check `expIsTensor`.
525   if (expr.kind == TensorExp::Kind::kTensor)
526     return expr.tensor == t;
527 
528   switch (getExpArity(expr.kind)) {
529   case ExpArity::kNullary:
530     return false;
531   case ExpArity::kUnary: {
532     const ExprId e0 = expr.children.e0;
533     return expContainsTensor(e0, t);
534   }
535   case ExpArity::kBinary: {
536     const ExprId e0 = expr.children.e0;
537     const ExprId e1 = expr.children.e1;
538     return expContainsTensor(e0, t) || expContainsTensor(e1, t);
539   }
540   }
541   llvm_unreachable("unexpected arity");
542 }
543 
544 bool Merger::hasNegateOnOut(ExprId e) const {
545   const auto &expr = exp(e);
546   switch (expr.kind) {
547   case TensorExp::Kind::kNegF:
548   case TensorExp::Kind::kNegC:
549   case TensorExp::Kind::kNegI:
550     return expContainsTensor(expr.children.e0, outTensor);
551   case TensorExp::Kind::kSubF:
552   case TensorExp::Kind::kSubC:
553   case TensorExp::Kind::kSubI:
554     return expContainsTensor(expr.children.e1, outTensor) ||
555            hasNegateOnOut(expr.children.e0);
556   case TensorExp::Kind::kDenseOp: {
557     bool lhsNeg = hasNegateOnOut(expr.children.e0);
558     if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
559       return hasNegateOnOut(expr.children.e1);
560     return lhsNeg;
561   }
562   default: {
563     switch (getExpArity(expr.kind)) {
564     case ExpArity::kNullary:
565       return false;
566     case ExpArity::kUnary:
567       return hasNegateOnOut(expr.children.e0);
568     case ExpArity::kBinary:
569       return hasNegateOnOut(expr.children.e0) ||
570              hasNegateOnOut(expr.children.e1);
571     }
572   }
573   }
574   llvm_unreachable("unexpected kind");
575 }
576 
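// A single condition means the expression can be evaluated under the sole
// guard that tensor t has a stored element. For example (sketch), for t = x,
// the expressions x(i) * 2.0 and x(i) / c (c invariant and nonzero) are
// single conditions on x, whereas x(i) + y(i) is not, since the sum must
// also be produced where only y has a stored entry.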
577 bool Merger::isSingleCondition(TensorId t, ExprId e) const {
578   assert(isValidTensorId(t));
579   const auto &expr = exp(e);
580   switch (expr.kind) {
581   // Leaf.
582   case TensorExp::Kind::kTensor:
583     return expr.tensor == t;
584   case TensorExp::Kind::kInvariant:
585   case TensorExp::Kind::kLoopVar:
586   case TensorExp::Kind::kSynZero:
587     return false;
588   // Unary operations.
589   case TensorExp::Kind::kAbsF:
590   case TensorExp::Kind::kAbsC:
591   case TensorExp::Kind::kAbsI:
592   case TensorExp::Kind::kCeilF:
593   case TensorExp::Kind::kFloorF:
594   case TensorExp::Kind::kSqrtF:
595   case TensorExp::Kind::kSqrtC:
596   case TensorExp::Kind::kExpm1F:
597   case TensorExp::Kind::kExpm1C:
598   case TensorExp::Kind::kLog1pF:
599   case TensorExp::Kind::kLog1pC:
600   case TensorExp::Kind::kSinF:
601   case TensorExp::Kind::kSinC:
602   case TensorExp::Kind::kTanhF:
603   case TensorExp::Kind::kTanhC:
604   case TensorExp::Kind::kNegF:
605   case TensorExp::Kind::kNegC:
606   case TensorExp::Kind::kNegI:
607   case TensorExp::Kind::kTruncF:
608   case TensorExp::Kind::kExtF:
609   case TensorExp::Kind::kCastFS:
610   case TensorExp::Kind::kCastFU:
611   case TensorExp::Kind::kCastSF:
612   case TensorExp::Kind::kCastUF:
613   case TensorExp::Kind::kCastS:
614   case TensorExp::Kind::kCastU:
615   case TensorExp::Kind::kCastIdx:
616   case TensorExp::Kind::kTruncI:
617   case TensorExp::Kind::kCIm:
618   case TensorExp::Kind::kCRe:
619   case TensorExp::Kind::kBitCast:
620   case TensorExp::Kind::kUnary:
621     return isSingleCondition(t, expr.children.e0);
622   case TensorExp::Kind::kBinaryBranch:
623   case TensorExp::Kind::kSelect:
624     return false;
625   // Binary operations.
626   case TensorExp::Kind::kDivF: // note: x / c only
627   case TensorExp::Kind::kDivC:
628   case TensorExp::Kind::kDivS:
629   case TensorExp::Kind::kDivU:
630     assert(!maybeZero(expr.children.e1));
631     return isSingleCondition(t, expr.children.e0);
632   case TensorExp::Kind::kShrS: // note: x >> inv only
633   case TensorExp::Kind::kShrU:
634   case TensorExp::Kind::kShlI:
635     assert(isInvariant(expr.children.e1));
636     return isSingleCondition(t, expr.children.e0);
637   case TensorExp::Kind::kMulF:
638   case TensorExp::Kind::kMulC:
639   case TensorExp::Kind::kMulI:
640   case TensorExp::Kind::kAndI:
641   case TensorExp::Kind::kReduce:
642     if (isSingleCondition(t, expr.children.e0))
643       return isSingleCondition(t, expr.children.e1) ||
644              isInvariant(expr.children.e1);
645     if (isSingleCondition(t, expr.children.e1))
646       return isInvariant(expr.children.e0);
647     return false;
648   case TensorExp::Kind::kAddF:
649   case TensorExp::Kind::kAddC:
650   case TensorExp::Kind::kAddI:
651     return isSingleCondition(t, expr.children.e0) &&
652            isSingleCondition(t, expr.children.e1);
653   case TensorExp::Kind::kSubF:
654   case TensorExp::Kind::kSubC:
655   case TensorExp::Kind::kSubI:
656   case TensorExp::Kind::kOrI:
657   case TensorExp::Kind::kXorI:
658   case TensorExp::Kind::kCmpF:
659   case TensorExp::Kind::kCmpI:
660   case TensorExp::Kind::kBinary:
661     return false;
662   case TensorExp::Kind::kDenseOp:
663     // Since the Merger guarantees that all operands of a kDenseOp are dense, the
664     // operation must be single-condition.
665     return true;
666   }
667   llvm_unreachable("unexpected kind");
668 }
669 
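// A set of loop conditions requires sparse co-iteration if any bit refers to
// a compressed, singleton, or loose-compressed level, or to a level whose
// coordinates are defined by a non-trivial (sparse) index expression.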
670 bool Merger::hasAnySparse(const BitVector &bits) const {
671   for (TensorLoopId b : bits.set_bits()) {
672     const auto dlt = getLvlType(b);
673     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
674         isLooseCompressedDLT(dlt))
675       return true;
676   }
677   return hasSparseIdxReduction(bits);
678 }
679 
680 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
681   for (TensorLoopId b : bits.set_bits())
682     if (isSparseLvlWithNonTrivialIdxExp(b))
683       return true;
684   return false;
685 }
686 
687 #ifndef NDEBUG
688 
689 //===----------------------------------------------------------------------===//
690 // Print methods (for debugging).
691 //===----------------------------------------------------------------------===//
692 
693 static const char *kindToOpSymbol(TensorExp::Kind kind) {
694   switch (kind) {
695   // Leaf.
696   case TensorExp::Kind::kTensor:
697     return "tensor";
698   case TensorExp::Kind::kInvariant:
699     return "invariant";
700   case TensorExp::Kind::kLoopVar:
701     return "index";
702   case TensorExp::Kind::kSynZero:
703     return "0";
704   // Unary operations.
705   case TensorExp::Kind::kAbsF:
706   case TensorExp::Kind::kAbsC:
707   case TensorExp::Kind::kAbsI:
708     return "abs";
709   case TensorExp::Kind::kCeilF:
710     return "ceil";
711   case TensorExp::Kind::kFloorF:
712     return "floor";
713   case TensorExp::Kind::kSqrtF:
714   case TensorExp::Kind::kSqrtC:
715     return "sqrt";
716   case TensorExp::Kind::kExpm1F:
717   case TensorExp::Kind::kExpm1C:
718     return "expm1";
719   case TensorExp::Kind::kLog1pF:
720   case TensorExp::Kind::kLog1pC:
721     return "log1p";
722   case TensorExp::Kind::kSinF:
723   case TensorExp::Kind::kSinC:
724     return "sin";
725   case TensorExp::Kind::kTanhF:
726   case TensorExp::Kind::kTanhC:
727     return "tanh";
728   case TensorExp::Kind::kNegF:
729   case TensorExp::Kind::kNegC:
730   case TensorExp::Kind::kNegI:
731     return "-";
732   case TensorExp::Kind::kTruncF:
733   case TensorExp::Kind::kExtF:
734   case TensorExp::Kind::kCastFS:
735   case TensorExp::Kind::kCastFU:
736   case TensorExp::Kind::kCastSF:
737   case TensorExp::Kind::kCastUF:
738   case TensorExp::Kind::kCastS:
739   case TensorExp::Kind::kCastU:
740   case TensorExp::Kind::kCastIdx:
741   case TensorExp::Kind::kTruncI:
742   case TensorExp::Kind::kCIm:
743     return "complex.im";
744   case TensorExp::Kind::kCRe:
745     return "complex.re";
746   case TensorExp::Kind::kBitCast:
747     return "cast";
748   case TensorExp::Kind::kBinaryBranch:
749     return "binary_branch";
750   case TensorExp::Kind::kUnary:
751     return "unary";
752   case TensorExp::Kind::kSelect:
753     return "select";
754   // Binary operations.
755   case TensorExp::Kind::kMulF:
756   case TensorExp::Kind::kMulC:
757   case TensorExp::Kind::kMulI:
758     return "*";
759   case TensorExp::Kind::kDivF:
760   case TensorExp::Kind::kDivC:
761   case TensorExp::Kind::kDivS:
762   case TensorExp::Kind::kDivU:
763     return "/";
764   case TensorExp::Kind::kAddF:
765   case TensorExp::Kind::kAddC:
766   case TensorExp::Kind::kAddI:
767     return "+";
768   case TensorExp::Kind::kSubF:
769   case TensorExp::Kind::kSubC:
770   case TensorExp::Kind::kSubI:
771     return "-";
772   case TensorExp::Kind::kAndI:
773     return "&";
774   case TensorExp::Kind::kOrI:
775     return "|";
776   case TensorExp::Kind::kXorI:
777     return "^";
778   case TensorExp::Kind::kShrS:
779     return "a>>";
780   case TensorExp::Kind::kShrU:
781     return ">>";
782   case TensorExp::Kind::kShlI:
783     return "<<";
784   case TensorExp::Kind::kCmpF:
785   case TensorExp::Kind::kCmpI:
786     return "cmp";
787   case TensorExp::Kind::kBinary:
788     return "binary";
789   case TensorExp::Kind::kReduce:
790     return "reduce";
791   case TensorExp::Kind::kDenseOp:
792     return "dense";
793   }
794   llvm_unreachable("unexpected kind for symbol");
795 }
796 
797 void Merger::dumpExp(ExprId e) const {
798   const auto &expr = exp(e);
799   switch (expr.kind) {
800   // Leaf.
801   case TensorExp::Kind::kTensor:
802     if (expr.tensor == syntheticTensor)
803       llvm::dbgs() << "synthetic_";
804     else if (expr.tensor == outTensor)
805       llvm::dbgs() << "output_";
806     llvm::dbgs() << "tensor_" << expr.tensor;
807     break;
808   case TensorExp::Kind::kInvariant:
809     llvm::dbgs() << "invariant";
810     break;
811   case TensorExp::Kind::kSynZero:
812     llvm::dbgs() << "0";
813     break;
814   case TensorExp::Kind::kLoopVar:
815     llvm::dbgs() << "loopvar_" << expr.loop;
816     break;
817   // Unary operations.
818   case TensorExp::Kind::kAbsF:
819   case TensorExp::Kind::kAbsC:
820   case TensorExp::Kind::kAbsI:
821   case TensorExp::Kind::kCeilF:
822   case TensorExp::Kind::kFloorF:
823   case TensorExp::Kind::kSqrtF:
824   case TensorExp::Kind::kSqrtC:
825   case TensorExp::Kind::kExpm1F:
826   case TensorExp::Kind::kExpm1C:
827   case TensorExp::Kind::kLog1pF:
828   case TensorExp::Kind::kLog1pC:
829   case TensorExp::Kind::kSinF:
830   case TensorExp::Kind::kSinC:
831   case TensorExp::Kind::kTanhF:
832   case TensorExp::Kind::kTanhC:
833   case TensorExp::Kind::kNegF:
834   case TensorExp::Kind::kNegC:
835   case TensorExp::Kind::kNegI:
836   case TensorExp::Kind::kTruncF:
837   case TensorExp::Kind::kExtF:
838   case TensorExp::Kind::kCastFS:
839   case TensorExp::Kind::kCastFU:
840   case TensorExp::Kind::kCastSF:
841   case TensorExp::Kind::kCastUF:
842   case TensorExp::Kind::kCastS:
843   case TensorExp::Kind::kCastU:
844   case TensorExp::Kind::kCastIdx:
845   case TensorExp::Kind::kTruncI:
846   case TensorExp::Kind::kCIm:
847   case TensorExp::Kind::kCRe:
848   case TensorExp::Kind::kBitCast:
849   case TensorExp::Kind::kBinaryBranch:
850   case TensorExp::Kind::kUnary:
851   case TensorExp::Kind::kSelect:
852     llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
853     dumpExp(expr.children.e0);
854     break;
855   // Binary operations.
856   case TensorExp::Kind::kMulF:
857   case TensorExp::Kind::kMulC:
858   case TensorExp::Kind::kMulI:
859   case TensorExp::Kind::kDivF:
860   case TensorExp::Kind::kDivC:
861   case TensorExp::Kind::kDivS:
862   case TensorExp::Kind::kDivU:
863   case TensorExp::Kind::kAddF:
864   case TensorExp::Kind::kAddC:
865   case TensorExp::Kind::kAddI:
866   case TensorExp::Kind::kSubF:
867   case TensorExp::Kind::kSubC:
868   case TensorExp::Kind::kSubI:
869   case TensorExp::Kind::kAndI:
870   case TensorExp::Kind::kOrI:
871   case TensorExp::Kind::kXorI:
872   case TensorExp::Kind::kShrS:
873   case TensorExp::Kind::kShrU:
874   case TensorExp::Kind::kShlI:
875   case TensorExp::Kind::kCmpF:
876   case TensorExp::Kind::kCmpI:
877   case TensorExp::Kind::kBinary:
878   case TensorExp::Kind::kReduce:
879   case TensorExp::Kind::kDenseOp:
880     llvm::dbgs() << "(";
881     dumpExp(expr.children.e0);
882     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
883     if (expr.attr)
884       llvm::dbgs() << "{" << expr.attr << "}";
885     if (expr.children.e1 != detail::kInvalidId) {
886       llvm::dbgs() << " ";
887       dumpExp(expr.children.e1);
888       llvm::dbgs() << ")";
889     } else {
890       assert(expr.kind == TensorExp::Kind::kDenseOp);
891     }
892     break;
893   }
894 }
895 
896 void Merger::dumpLat(LatPointId p) const {
897   const auto &point = lat(p);
898   llvm::dbgs() << "lat(";
899   dumpBits(point.bits);
900   llvm::dbgs() << " :";
901   dumpBits(point.simple);
902   llvm::dbgs() << " : ";
903   dumpExp(point.exp);
904   llvm::dbgs() << " )\n";
905 }
906 
907 void Merger::dumpSet(LatSetId s) const {
908   const auto &ss = set(s);
909   llvm::dbgs() << "{ #" << ss.size() << "\n";
910   for (const LatPointId p : ss) {
911     llvm::dbgs() << "  ";
912     dumpLat(p);
913   }
914   llvm::dbgs() << "}\n";
915 }
916 
917 void Merger::dumpBits(const BitVector &bits) const {
918   for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
919     if (bits[b]) {
920       const TensorId t = tensor(b);
921       const LoopId i = loop(b);
922       const auto dlt = lvlTypes[t][i];
923       if (isLvlWithNonTrivialIdxExp(b))
924         llvm::dbgs() << " DEP_" << t << "_" << i;
925       else
926         llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
927     }
928   }
929 }
930 
931 #endif // NDEBUG
932 
933 //===----------------------------------------------------------------------===//
934 // Builder methods.
935 //===----------------------------------------------------------------------===//
936 
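// A worked example (sketch): for a(i) = b(i) + c(i) with both operands
// sparse, buildLattices proceeds bottom-up. Each leaf contributes a
// singleton set for loop i, e.g. { {b} : b } and { {c} : c }. The addition
// is a disjunction, yielding { {b,c} : b+c, {b} : b, {c} : c }, i.e. three
// loop bodies: co-iteration where both operands have stored entries, and
// trailing iteration where only one of them does. A multiplication
// b(i) * c(i) would instead yield the single conjunction point
// { {b,c} : b*c }.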
937 LatSetId Merger::buildLattices(ExprId e, LoopId i) {
938   // NOTE: The `expr` reference will be invalidated by recursive calls
939   // (and any other method that may add new expressions); therefore, the
940   // code below must make sure to copy fields of `expr` into local variables
941   // before making any recursive calls.
942   const auto &expr = exp(e);
943   const TensorExp::Kind kind = expr.kind;
944   switch (kind) {
945   // Leaf.
946   case TensorExp::Kind::kTensor:
947   case TensorExp::Kind::kInvariant:
948   case TensorExp::Kind::kSynZero:
949   case TensorExp::Kind::kLoopVar: {
950     // Either the loop-var is really used in the tensor expression, or it is
951     // set to the undefined loop-var in that level. An invariant expression,
952     // a proper index value, and a truly dynamic sparse output tensor are set
953     // to a synthetic tensor with undefined indices only to ensure the
954     // iteration space is not skipped as a result of their contents.
955     const LatSetId s = addSet();
956     TensorId t = syntheticTensor;
957     if (kind == TensorExp::Kind::kTensor) {
958       t = expr.tensor;
959       if (hasSparseOut && t == outTensor)
960         t = syntheticTensor;
961     }
962     latSets[s].push_back(addLat(t, i, e));
963     return s;
964   }
965   // Unary operations.
966   case TensorExp::Kind::kAbsF:
967   case TensorExp::Kind::kAbsC:
968   case TensorExp::Kind::kAbsI:
969   case TensorExp::Kind::kCeilF:
970   case TensorExp::Kind::kFloorF:
971   case TensorExp::Kind::kSqrtF:
972   case TensorExp::Kind::kSqrtC:
973   case TensorExp::Kind::kExpm1F:
974   case TensorExp::Kind::kExpm1C:
975   case TensorExp::Kind::kLog1pF:
976   case TensorExp::Kind::kLog1pC:
977   case TensorExp::Kind::kSinF:
978   case TensorExp::Kind::kSinC:
979   case TensorExp::Kind::kTanhF:
980   case TensorExp::Kind::kTanhC:
981   case TensorExp::Kind::kNegF:
982   case TensorExp::Kind::kNegC:
983   case TensorExp::Kind::kNegI:
984   case TensorExp::Kind::kTruncF:
985   case TensorExp::Kind::kExtF:
986   case TensorExp::Kind::kCastFS:
987   case TensorExp::Kind::kCastFU:
988   case TensorExp::Kind::kCastSF:
989   case TensorExp::Kind::kCastUF:
990   case TensorExp::Kind::kCastS:
991   case TensorExp::Kind::kCastU:
992   case TensorExp::Kind::kCastIdx:
993   case TensorExp::Kind::kTruncI:
994   case TensorExp::Kind::kCIm:
995   case TensorExp::Kind::kCRe:
996   case TensorExp::Kind::kBitCast:
997     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
998     // lattice set of the operand through the operator into a new set.
999     //
1000     //  -y|!y | y |
1001     //  --+---+---+
1002     //    | 0 |-y |
1003     {
1004       const ExprId e0 = expr.children.e0;
1005       const Value v = expr.val;
1006       return mapSet(kind, buildLattices(e0, i), v);
1007     }
1008   case TensorExp::Kind::kBinaryBranch:
1009   case TensorExp::Kind::kSelect:
1010     // The left or right half of a binary operation which has already
1011     // been split into separate operations for each region.
1012     {
1013       const ExprId e0 = expr.children.e0;
1014       Operation *const op = expr.op;
1015       return mapSet(kind, buildLattices(e0, i), Value(), op);
1016     }
1017   case TensorExp::Kind::kUnary:
1018     // A custom unary operation.
1019     //
1020     //  op y|    !y    |     y      |
1021     //  ----+----------+------------+
1022     //      | absent() | present(y) |
1023     {
1024       const ExprId e0 = expr.children.e0;
1025       UnaryOp unop = cast<UnaryOp>(expr.op);
1026       const LatSetId child0 = buildLattices(e0, i);
1027       Region &absentRegion = unop.getAbsentRegion();
1028       if (absentRegion.empty()) {
1029         // Simple mapping over existing values.
1030         return mapSet(kind, child0, Value(), unop);
1031       }
1032       // Use a disjunction with `unop` on the left and the absent value as an
1033       // invariant on the right.
1034       Block &absentBlock = absentRegion.front();
1035       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1036       const Value absentVal = absentYield.getResult();
1037       const ExprId rhs = addInvariantExp(absentVal);
1038       return disjSet(e, child0, buildLattices(rhs, i), unop);
1039     }
1040   // Binary operations.
1041   case TensorExp::Kind::kMulF:
1042   case TensorExp::Kind::kMulC:
1043   case TensorExp::Kind::kMulI:
1044   case TensorExp::Kind::kAndI:
1045     // A multiplicative operation only needs to be performed
1046     // for the conjunction of sparse iteration spaces.
1047     //
1048     //  x*y|!y | y |
1049     //  ---+---+---+
1050     //  !x | 0 | 0 |
1051     //   x | 0 |x*y|
1052     //
1053     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1054     {
1055       const ExprId e0 = expr.children.e0;
1056       const ExprId e1 = expr.children.e1;
1057       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1058     }
1059   case TensorExp::Kind::kDivF:
1060   case TensorExp::Kind::kDivC:
1061   case TensorExp::Kind::kDivS:
1062   case TensorExp::Kind::kDivU:
1063     // A division is tricky, since 0/0, 0/c, c/0 all have
1064     // specific outcomes for floating-point and integers.
1065     // Thus, we need to traverse the full iteration space.
1066     //
1067     //  x/y|!y | y |
1068     //  ---+---+---+
1069     //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1070     //   x |x/0|x/y|  INT: x/0=exception for any x
1071     //
1072     // TODO: for now we "fixed" this by only accepting x/c cases
1073     //       during expression building, so that the conjunction
1074     //       rules applies (viz. x/c = x*(1/c) as far as lattice
1075     //       construction is concerned).
1076     {
1077       const ExprId e0 = expr.children.e0;
1078       const ExprId e1 = expr.children.e1;
1079       assert(!maybeZero(e1));
1080       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1081     }
1082   case TensorExp::Kind::kAddF:
1083   case TensorExp::Kind::kAddC:
1084   case TensorExp::Kind::kAddI:
1085   case TensorExp::Kind::kSubF:
1086   case TensorExp::Kind::kSubC:
1087   case TensorExp::Kind::kSubI:
1088   case TensorExp::Kind::kOrI:
1089   case TensorExp::Kind::kXorI:
1090     // An additive operation needs to be performed
1091     // for the disjunction of sparse iteration spaces.
1092     //
1093     //  x+y|!y | y |    x-y|!y | y |
1094     //  ---+---+---+    ---+---+---+
1095     //  !x | 0 | y |    !x | 0 |-y |
1096     //   x | x |x+y|     x | x |x-y|
1097     {
1098       const ExprId e0 = expr.children.e0;
1099       const ExprId e1 = expr.children.e1;
1100       return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1101     }
1102   case TensorExp::Kind::kCmpF:
1103   case TensorExp::Kind::kCmpI:
1104     // A comparison operation needs to be performed
1105     // for the disjunction of sparse iteration spaces.
1106     //
1107     //   x < y |  !y   |   y   |
1108     //  -------+-------+-------+
1109     //     !x  |   0   | 0 < y |
1110     //      x  | x < 0 | x < y |
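    //
    // For example (sketch), x(i) < y(i) must also be evaluated where only
    // one operand has a stored entry: disjSetWithZero adds the extra cases
    // 0 < y and x < 0 using a synthetic zero operand. Predicates that
    // include equality are rejected during expression building (see
    // buildTensorExp), since e.g. 0 <= 0 yields true and would densify the
    // output.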
1111     {
1112       const ExprId e0 = expr.children.e0;
1113       const ExprId e1 = expr.children.e1;
1114       return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1115     }
1116   case TensorExp::Kind::kShrS:
1117   case TensorExp::Kind::kShrU:
1118   case TensorExp::Kind::kShlI:
1119     // A shift operation by an invariant amount (viz. tensor expressions
1120     // can only occur at the left-hand-side of the operator) can be handled
1121     // with the conjunction rule.
1122     {
1123       const ExprId e0 = expr.children.e0;
1124       const ExprId e1 = expr.children.e1;
1125       assert(isInvariant(e1));
1126       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1127     }
1128   case TensorExp::Kind::kBinary:
1129     // A custom binary operation.
1130     //
1131     //  x op y|   !y    |       y      |
1132     //  ------+---------+--------------+
1133     //    !x  |  empty  |   right(y)   |
1134     //     x  | left(x) | overlap(x,y) |
1135     {
1136       const ExprId e0 = expr.children.e0;
1137       const ExprId e1 = expr.children.e1;
1138       BinaryOp binop = cast<BinaryOp>(expr.op);
1139       const LatSetId child0 = buildLattices(e0, i);
1140       const LatSetId child1 = buildLattices(e1, i);
1141       Region &leftRegion = binop.getLeftRegion();
1142       Region &rightRegion = binop.getRightRegion();
1143       // Left Region.
1144       Operation *leftYield = nullptr;
1145       if (!leftRegion.empty()) {
1146         Block &leftBlock = leftRegion.front();
1147         leftYield = leftBlock.getTerminator();
1148       }
1149       // Right Region.
1150       Operation *rightYield = nullptr;
1151       if (!rightRegion.empty()) {
1152         Block &rightBlock = rightRegion.front();
1153         rightYield = rightBlock.getTerminator();
1154       }
1155       bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1156       bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1157       return combiSet(e, child0, child1, binop, includeLeft,
1158                       TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1159                       TensorExp::Kind::kBinaryBranch, rightYield);
1160     }
1161   case TensorExp::Kind::kReduce:
1162     // A custom reduce operation.
1163     {
1164       const ExprId e0 = expr.children.e0;
1165       const ExprId e1 = expr.children.e1;
1166       Operation *const op = expr.op;
1167       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1168     }
1169   case TensorExp::Kind::kDenseOp: {
1170     // It does not really matter whether we use a conjunctive or disjunctive set
1171     // here: since all operands of a kDenseOp must be dense, the disjunctive set
1172     // will eventually be optimized into a conjunctive set anyway.
1173     if (expr.children.e1 == detail::kInvalidId) {
1174       const ExprId e0 = expr.children.e0;
1175       Operation *const op = expr.op;
1176       return mapSet(kind, buildLattices(e0, i), Value(), op);
1177     }
1178 
1179     const ExprId e0 = expr.children.e0;
1180     const ExprId e1 = expr.children.e1;
1181     Operation *const op = expr.op;
1182     return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1183   }
1184   }
1185   llvm_unreachable("unexpected expression kind");
1186 }
1187 
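// A rough usage sketch (hypothetical driver, not the actual sparsification
// pass); `loopOrder` and the emission step are placeholders:
//
//   Merger merger(numInputOutputTensors, numNativeLoops,
//                 /*numFilterLoops=*/0, maxLvlRank);
//   std::optional<ExprId> root = merger.buildTensorExpFromLinalg(op);
//   if (!root)
//     return failure(); // kernel is not sparsifiable
//   for (LoopId i : loopOrder) {
//     LatSetId s = merger.optimizeSet(merger.buildLattices(*root, i));
//     // ... emit one loop body per lattice point in set `s` ...
//   }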
1188 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1189   // Build the linalg semantics backward from yield.
1190   Operation *yield = op.getRegion().front().getTerminator();
1191   assert(isa<linalg::YieldOp>(yield));
1192   return buildTensorExp(op, yield->getOperand(0)).first;
1193 }
1194 
1195 /// Only returns false if we are certain this is a nonzero.
1196 bool Merger::maybeZero(ExprId e) const {
1197   const auto &expr = exp(e);
1198   if (expr.kind == TensorExp::Kind::kInvariant) {
1199     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1200       ArrayAttr arrayAttr = c.getValue();
1201       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1202              cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1203     }
1204     if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1205       return c.value() == 0;
1206     if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1207       return c.value().isZero();
1208   }
1209   return true;
1210 }
1211 
1212 Type Merger::inferType(ExprId e, Value src) const {
1213   // Obtain the destination type from the cast node.
1214   Type dtp = exp(e).val.getType();
1215   // Inspect source type. For vector types, apply the same
1216   // vectorization to the destination type.
1217   if (auto vtp = dyn_cast<VectorType>(src.getType()))
1218     return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1219   return dtp;
1220 }
1221 
1222 /// Ensures that the sparse compiler can generate code for the expression.
1223 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1224   // Arguments are always admissible.
1225   if (isa<BlockArgument>(v))
1226     return true;
1227   // Accept index anywhere.
1228   Operation *def = v.getDefiningOp();
1229   if (isa<linalg::IndexOp>(def))
1230     return true;
1231   // Operation defined outside branch.
1232   if (def->getBlock() != block)
1233     return def->getBlock() != op->getBlock(); // invariant?
1234   // Operation defined within branch. Anything is accepted,
1235   // as long as all subexpressions are admissible.
1236   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1237     if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1238       return false;
1239   return true;
1240 }
1241 
1242 /// Ensures that the sparse compiler can generate code for the branch.
1243 static bool isAdmissibleBranch(Operation *op, Region &region) {
1244   if (region.empty())
1245     return true;
1246   // Build the semi-ring branch semantics backward from yield.
1247   Operation *yield = region.front().getTerminator();
1248   assert(isa<YieldOp>(yield));
1249   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1250 }
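// As an example of the admissibility rules above (sketch): a branch region
// that yields `arith.addf %x, %cst`, where %x is the region's block argument
// and %cst is defined outside the enclosing linalg op, is admissible; a
// region that yields a value computed elsewhere inside the linalg body
// (neither in the branch itself nor invariant to the op) is rejected.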
1251 
1252 std::pair<std::optional<ExprId>, bool>
1253 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1254   // Recursion leaves.
1255   if (auto arg = dyn_cast<BlockArgument>(v)) {
1256     const TensorId tid = makeTensorId(arg.getArgNumber());
1257     // Any argument of the generic op that is not marked as a scalar
1258     // argument is considered a tensor, indexed by the implicit loop
1259     // bounds. This includes rank-0 tensor arguments.
1260     if (arg.getOwner()->getParentOp() == op) {
1261       OpOperand &t = op->getOpOperand(tid);
1262       bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1263       if (!op.isScalar(&t))
1264         return {addTensorExp(tid), hasSpDep};
1265       v = t.get(); // get scalar value
1266     }
1267     // Any other argument (marked as scalar argument for the generic op
1268     // or belonging to an enveloping op) is considered invariant.
1269     return {addInvariantExp(v), /*hasSpDep=*/false};
1270   }
1271   // Something defined outside is invariant.
1272   Operation *def = v.getDefiningOp();
1273   if (def->getBlock() != &op.getRegion().front())
1274     return {addInvariantExp(v), /*hasSpDep=*/false};
1275   // Construct index operations.
1276   if (def->getNumOperands() == 0) {
1277     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1278       return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1279   }
1280 
1281   // Construct unary operations if subexpression can be built.
1282   if (def->getNumOperands() == 1) {
1283     const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1284     if (x.has_value()) {
1285       const ExprId e = *x;
1286       if (isa<math::AbsFOp>(def))
1287         return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1288       if (isa<complex::AbsOp>(def))
1289         return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1290       if (isa<math::AbsIOp>(def))
1291         return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1292       if (isa<math::CeilOp>(def))
1293         return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1294       if (isa<math::FloorOp>(def))
1295         return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1296       if (isa<math::SqrtOp>(def))
1297         return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1298       if (isa<complex::SqrtOp>(def))
1299         return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1300       if (isa<math::ExpM1Op>(def))
1301         return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1302       if (isa<complex::Expm1Op>(def))
1303         return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1304       if (isa<math::Log1pOp>(def))
1305         return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1306       if (isa<complex::Log1pOp>(def))
1307         return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1308       if (isa<math::SinOp>(def))
1309         return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1310       if (isa<complex::SinOp>(def))
1311         return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1312       if (isa<math::TanhOp>(def))
1313         return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1314       if (isa<complex::TanhOp>(def))
1315         return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1316       if (isa<arith::NegFOp>(def))
1317         return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1318       if (isa<complex::NegOp>(def))
1319         return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1320       if (isa<arith::TruncFOp>(def))
1321         return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1322       if (isa<arith::ExtFOp>(def))
1323         return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1324       if (isa<arith::FPToSIOp>(def))
1325         return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1326       if (isa<arith::FPToUIOp>(def))
1327         return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1328       if (isa<arith::SIToFPOp>(def))
1329         return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1330       if (isa<arith::UIToFPOp>(def))
1331         return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1332       if (isa<arith::ExtSIOp>(def))
1333         return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1334       if (isa<arith::ExtUIOp>(def))
1335         return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1336       if (isa<arith::IndexCastOp>(def))
1337         return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1338       if (isa<arith::TruncIOp>(def))
1339         return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1340       if (isa<complex::ImOp>(def))
1341         return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1342       if (isa<complex::ReOp>(def))
1343         return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1344       if (isa<arith::BitcastOp>(def))
1345         return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1346       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1347         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1348             isAdmissibleBranch(unop, unop.getAbsentRegion()))
1349           return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1350       }
1351       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1352         if (isAdmissibleBranch(selop, selop.getRegion()))
1353           return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1354       }
1355     }
1356   }
1357   // Construct binary operations if subexpressions can be built.
1358   // See buildLattices() for an explanation of rejecting certain
1359   // division and shift operations.
1360   if (def->getNumOperands() == 2) {
1361     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1362     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1363     bool hasSpDep = xDepSp || yDepSp;
1364     if (x.has_value() && y.has_value()) {
1365       const ExprId e0 = *x;
1366       const ExprId e1 = *y;
1367       if (isa<arith::MulFOp>(def))
1368         return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
1369       if (isa<complex::MulOp>(def))
1370         return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
1371       if (isa<arith::MulIOp>(def))
1372         return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
1373       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1374         return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
1375       if (isa<complex::DivOp>(def) && !maybeZero(e1))
1376         return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
1377       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1378         return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
1379       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1380         return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
1381       if (isa<arith::AddFOp>(def))
1382         return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
1383       if (isa<complex::AddOp>(def))
1384         return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
1385       if (isa<arith::AddIOp>(def))
1386         return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
1387       if (isa<arith::SubFOp>(def))
1388         return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
1389       if (isa<complex::SubOp>(def))
1390         return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
1391       if (isa<arith::SubIOp>(def))
1392         return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
1393       if (isa<arith::AndIOp>(def))
1394         return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
1395       if (isa<arith::OrIOp>(def))
1396         return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
1397       if (isa<arith::XOrIOp>(def))
1398         return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
1399       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1400         return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
1401       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1402         return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
1403       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1404         return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
1405       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1406         if (ci.getPredicate() == arith::CmpIPredicate::eq ||
1407             ci.getPredicate() == arith::CmpIPredicate::sle ||
1408             ci.getPredicate() == arith::CmpIPredicate::sge ||
1409             ci.getPredicate() == arith::CmpIPredicate::ule ||
1410             ci.getPredicate() == arith::CmpIPredicate::uge) {
1411           // We cannot sparsify comparisons that include equality, because
1412           // 0 <= 0 yields true and would thus densify the result.
1413           return {std::nullopt, false};
1414         }
1415 
1416         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1417                         ci.getPredicateAttr());
1418         return {e, hasSpDep};
1419       }
1420       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1421         if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
1422             cf.getPredicate() == arith::CmpFPredicate::OGE ||
1423             cf.getPredicate() == arith::CmpFPredicate::OLE ||
1424             cf.getPredicate() == arith::CmpFPredicate::ONE ||
1425             cf.getPredicate() == arith::CmpFPredicate::UEQ ||
1426             cf.getPredicate() == arith::CmpFPredicate::UGE ||
1427             cf.getPredicate() == arith::CmpFPredicate::ULE ||
1428             cf.getPredicate() == arith::CmpFPredicate::ORD ||
1429             cf.getPredicate() == arith::CmpFPredicate::UNO) {
1430           // We cannot sparsify comparisons that include equality, because
1431           // 0 <= 0 yields true and would thus densify the result.
1432           return {std::nullopt, false};
1433         }
1434         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1435                         cf.getPredicateAttr());
1436         return {e, hasSpDep};
1437       }
1438       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1439         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1440             (binop.getLeftIdentity() ||
1441              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1442             (binop.getRightIdentity() ||
1443              isAdmissibleBranch(binop, binop.getRightRegion())))
1444           return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
1445       }
1446     }
1447   }
1448   // Construct ternary operations if subexpressions can be built.
1449   if (def->getNumOperands() == 3) {
1450     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1451     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1452     const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1453     bool hasSpDep = xDepSp || yDepSp || zDepSp;
1454     if (x.has_value() && y.has_value() && z.has_value()) {
1455       const ExprId e0 = *x;
1456       const ExprId e1 = *y;
1457       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1458         if (isAdmissibleBranch(redop, redop.getRegion()))
1459           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1460       }
1461     }
1462   }
1463 
1464   // If we reach here, we are dealing with an operation that is not currently
1465   // sparsifiable. We can still generate code for it if all its operands only
1466   // have dense dependencies (i.e., all the values are loaded from dense
1467   // tensors).
1468   if (def->getNumResults() != 1) // only handle single-result operations.
1469     return {std::nullopt, false};
1470 
1471   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1472   // Builds all the sub-expressions
1473   for (Value operand : def->getOperands())
1474     subExp.push_back(buildTensorExp(op, operand));
1475 
1476   if (llvm::all_of(subExp,
1477                    [](auto e) { return e.first.has_value() && !e.second; })) {
1478     // All the subexpressions can be built and have *no* sparse dependencies.
1479     if (subExp.size() == 2) {
1480       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1481                       *subExp[1].first, def);
1482       return {e, false};
1483     }
1484     if (subExp.size() == 1) {
1485       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1486                       detail::kInvalidId, def);
1487       return {e, false};
1488     }
1489   }
1490   // Cannot build.
1491   return {std::nullopt, false};
1492 }
1493 
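/// Clones the given region, inlines the cloned block at the current insertion
/// point with `vals` as block arguments, and returns the value carried by the
/// region's YieldOp; the cloned yield and the temporary placeholder op are
/// erased afterwards.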
1494 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1495                            ValueRange vals) {
1496   // Make a clone of the given region.
1497   Region tmpRegion;
1498   IRMapping mapper;
1499   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1500   Block &clonedBlock = tmpRegion.front();
1501   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1502   // Merge cloned block and return yield value.
1503   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1504   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1505   Value val = clonedYield.getResult();
1506   rewriter.eraseOp(clonedYield);
1507   rewriter.eraseOp(placeholder);
1508   return val;
1509 }
1510 
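/// Emits the `present` region of a sparse_tensor UnaryOp for stored value
/// `v0`. Returns an empty Value when the input is absent or the region is
/// empty, which the caller interprets as missing data in the output.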
1511 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1512                                Operation *op, Value v0) {
1513   if (!v0)
1514     // Empty input value must be propagated.
1515     return Value();
1516   UnaryOp unop = cast<UnaryOp>(op);
1517   Region &presentRegion = unop.getPresentRegion();
1518   if (presentRegion.empty())
1519     // Uninitialized Value() will be interpreted as missing data in the
1520     // output.
1521     return Value();
1522   return insertYieldOp(rewriter, loc, presentRegion, {v0});
1523 }
1524 
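/// Emits the `overlap` region of a sparse_tensor BinaryOp for stored values
/// `v0` and `v1`. Returns an empty Value when either input is absent or the
/// region is empty.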
1525 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1526                                 Operation *op, Value v0, Value v1) {
1527   if (!v0 || !v1)
1528     // Empty input values must be propagated.
1529     return Value();
1530   BinaryOp binop = cast<BinaryOp>(op);
1531   Region &overlapRegion = binop.getOverlapRegion();
1532   if (overlapRegion.empty())
1533     // Uninitialized Value() will be interpreted as missing data in the
1534     // output.
1535     return Value();
1536   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1537 }
1538 
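/// Emits the IR that evaluates tensor expression `e`, given the values `v0`
/// and `v1` already generated for its (at most two) operands.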
1539 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1540                        Value v1) const {
1541   const auto &expr = exp(e);
1542   switch (expr.kind) {
1543   // Leaf.
1544   case TensorExp::Kind::kTensor:
1545   case TensorExp::Kind::kInvariant:
1546   case TensorExp::Kind::kLoopVar:
1547   case TensorExp::Kind::kSynZero:
1548     llvm_unreachable("unexpected non-op");
1549   // Unary operations.
1550   case TensorExp::Kind::kAbsF:
1551     return rewriter.create<math::AbsFOp>(loc, v0);
1552   case TensorExp::Kind::kAbsC: {
1553     auto type = cast<ComplexType>(v0.getType());
1554     auto eltType = cast<FloatType>(type.getElementType());
1555     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1556   }
1557   case TensorExp::Kind::kAbsI:
1558     return rewriter.create<math::AbsIOp>(loc, v0);
1559   case TensorExp::Kind::kCeilF:
1560     return rewriter.create<math::CeilOp>(loc, v0);
1561   case TensorExp::Kind::kFloorF:
1562     return rewriter.create<math::FloorOp>(loc, v0);
1563   case TensorExp::Kind::kSqrtF:
1564     return rewriter.create<math::SqrtOp>(loc, v0);
1565   case TensorExp::Kind::kSqrtC:
1566     return rewriter.create<complex::SqrtOp>(loc, v0);
1567   case TensorExp::Kind::kExpm1F:
1568     return rewriter.create<math::ExpM1Op>(loc, v0);
1569   case TensorExp::Kind::kExpm1C:
1570     return rewriter.create<complex::Expm1Op>(loc, v0);
1571   case TensorExp::Kind::kLog1pF:
1572     return rewriter.create<math::Log1pOp>(loc, v0);
1573   case TensorExp::Kind::kLog1pC:
1574     return rewriter.create<complex::Log1pOp>(loc, v0);
1575   case TensorExp::Kind::kSinF:
1576     return rewriter.create<math::SinOp>(loc, v0);
1577   case TensorExp::Kind::kSinC:
1578     return rewriter.create<complex::SinOp>(loc, v0);
1579   case TensorExp::Kind::kTanhF:
1580     return rewriter.create<math::TanhOp>(loc, v0);
1581   case TensorExp::Kind::kTanhC:
1582     return rewriter.create<complex::TanhOp>(loc, v0);
1583   case TensorExp::Kind::kNegF:
1584     return rewriter.create<arith::NegFOp>(loc, v0);
1585   case TensorExp::Kind::kNegC:
1586     return rewriter.create<complex::NegOp>(loc, v0);
1587   case TensorExp::Kind::kNegI: // no negi in arith; emit 0 - v0 instead
1588     return rewriter.create<arith::SubIOp>(
1589         loc,
1590         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1591                                            rewriter.getZeroAttr(v0.getType())),
1592         v0);
1593   case TensorExp::Kind::kTruncF:
1594     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1595   case TensorExp::Kind::kExtF:
1596     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1597   case TensorExp::Kind::kCastFS:
1598     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1599   case TensorExp::Kind::kCastFU:
1600     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1601   case TensorExp::Kind::kCastSF:
1602     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1603   case TensorExp::Kind::kCastUF:
1604     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1605   case TensorExp::Kind::kCastS:
1606     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1607   case TensorExp::Kind::kCastU:
1608     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1609   case TensorExp::Kind::kCastIdx:
1610     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1611   case TensorExp::Kind::kTruncI:
1612     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1613   case TensorExp::Kind::kCIm: {
1614     auto type = cast<ComplexType>(v0.getType());
1615     auto eltType = cast<FloatType>(type.getElementType());
1616     return rewriter.create<complex::ImOp>(loc, eltType, v0);
1617   }
1618   case TensorExp::Kind::kCRe: {
1619     auto type = cast<ComplexType>(v0.getType());
1620     auto eltType = cast<FloatType>(type.getElementType());
1621     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1622   }
1623   case TensorExp::Kind::kBitCast:
1624     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1625   // Binary operations.
1626   case TensorExp::Kind::kMulF:
1627     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1628   case TensorExp::Kind::kMulC:
1629     return rewriter.create<complex::MulOp>(loc, v0, v1);
1630   case TensorExp::Kind::kMulI:
1631     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1632   case TensorExp::Kind::kDivF:
1633     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1634   case TensorExp::Kind::kDivC:
1635     return rewriter.create<complex::DivOp>(loc, v0, v1);
1636   case TensorExp::Kind::kDivS:
1637     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1638   case TensorExp::Kind::kDivU:
1639     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1640   case TensorExp::Kind::kAddF:
1641     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1642   case TensorExp::Kind::kAddC:
1643     return rewriter.create<complex::AddOp>(loc, v0, v1);
1644   case TensorExp::Kind::kAddI:
1645     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1646   case TensorExp::Kind::kSubF:
1647     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1648   case TensorExp::Kind::kSubC:
1649     return rewriter.create<complex::SubOp>(loc, v0, v1);
1650   case TensorExp::Kind::kSubI:
1651     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1652   case TensorExp::Kind::kAndI:
1653     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1654   case TensorExp::Kind::kOrI:
1655     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1656   case TensorExp::Kind::kXorI:
1657     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1658   case TensorExp::Kind::kShrS:
1659     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1660   case TensorExp::Kind::kShrU:
1661     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1662   case TensorExp::Kind::kShlI:
1663     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1664   case TensorExp::Kind::kCmpI: {
1665     auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1666     return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1667   }
1668   case TensorExp::Kind::kCmpF: {
1669     auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1670     return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1671   }
1672   case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1673     return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1674                          {v0});
1675   case TensorExp::Kind::kUnary:
1676     return buildUnaryPresent(rewriter, loc, expr.op, v0);
1677   case TensorExp::Kind::kSelect:
1678     return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1679                          {v0});
1680   case TensorExp::Kind::kBinary:
1681     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1682   case TensorExp::Kind::kReduce: {
1683     ReduceOp redOp = cast<ReduceOp>(expr.op);
1684     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1685   }
1686   case TensorExp::Kind::kDenseOp: {
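    // Rebuild the dense-only operation by cloning it, with its operands
    // remapped to the values computed for the subexpressions.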
1687     Operation *actualOp = expr.op;
1688     IRMapping mapping;
1689     mapping.map(actualOp->getOperand(0), v0);
1690     if (actualOp->getNumOperands() == 2)
1691       mapping.map(actualOp->getOperand(1), v1);
1692     return rewriter.clone(*actualOp, mapping)->getResult(0);
1693   }
1694   }
1695   llvm_unreachable("unexpected expression kind in build");
1696 }
1697 
1698 } // namespace sparse_tensor
1699 } // namespace mlir
1700