xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (revision d82e93e7f129d9e8b72570efdf4a15d6ec3d4336)
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace sparse_tensor {
21 
/// Number of sub-expression operands stored by a TensorExp kind.
enum class ExpArity {
  kNullary, // Leaf: no operand ids.
  kUnary,   // One operand id in `children.e0`.
  kBinary   // Operand ids in `children.e0` and `children.e1`.
};
27 
28 static ExpArity getExpArity(TensorExp::Kind k) {
29   switch (k) {
30   // Leaf.
31   case TensorExp::Kind::kTensor:
32   case TensorExp::Kind::kInvariant:
33   case TensorExp::Kind::kLoopVar:
34   case TensorExp::Kind::kSynZero:
35     return ExpArity::kNullary;
36   case TensorExp::Kind::kAbsF:
37   case TensorExp::Kind::kAbsC:
38   case TensorExp::Kind::kAbsI:
39   case TensorExp::Kind::kCeilF:
40   case TensorExp::Kind::kFloorF:
41   case TensorExp::Kind::kSqrtF:
42   case TensorExp::Kind::kSqrtC:
43   case TensorExp::Kind::kExpm1F:
44   case TensorExp::Kind::kExpm1C:
45   case TensorExp::Kind::kLog1pF:
46   case TensorExp::Kind::kLog1pC:
47   case TensorExp::Kind::kSinF:
48   case TensorExp::Kind::kSinC:
49   case TensorExp::Kind::kTanhF:
50   case TensorExp::Kind::kTanhC:
51   case TensorExp::Kind::kTruncF:
52   case TensorExp::Kind::kExtF:
53   case TensorExp::Kind::kCastFS:
54   case TensorExp::Kind::kCastFU:
55   case TensorExp::Kind::kCastSF:
56   case TensorExp::Kind::kCastUF:
57   case TensorExp::Kind::kCastS:
58   case TensorExp::Kind::kCastU:
59   case TensorExp::Kind::kCastIdx:
60   case TensorExp::Kind::kTruncI:
61   case TensorExp::Kind::kCIm:
62   case TensorExp::Kind::kCRe:
63   case TensorExp::Kind::kBitCast:
64   case TensorExp::Kind::kBinaryBranch:
65   case TensorExp::Kind::kUnary:
66   case TensorExp::Kind::kSelect:
67   case TensorExp::Kind::kNegF:
68   case TensorExp::Kind::kNegC:
69   case TensorExp::Kind::kNegI:
70     return ExpArity::kUnary;
71   // Binary operations.
72   case TensorExp::Kind::kDivF:
73   case TensorExp::Kind::kDivC:
74   case TensorExp::Kind::kDivS:
75   case TensorExp::Kind::kDivU:
76   case TensorExp::Kind::kShrS:
77   case TensorExp::Kind::kShrU:
78   case TensorExp::Kind::kShlI:
79   case TensorExp::Kind::kMulF:
80   case TensorExp::Kind::kMulC:
81   case TensorExp::Kind::kMulI:
82   case TensorExp::Kind::kAndI:
83   case TensorExp::Kind::kAddF:
84   case TensorExp::Kind::kAddC:
85   case TensorExp::Kind::kAddI:
86   case TensorExp::Kind::kOrI:
87   case TensorExp::Kind::kXorI:
88   case TensorExp::Kind::kBinary:
89   case TensorExp::Kind::kReduce:
90   case TensorExp::Kind::kSubF:
91   case TensorExp::Kind::kSubC:
92   case TensorExp::Kind::kSubI:
93   case TensorExp::Kind::kCmpF:
94   case TensorExp::Kind::kCmpI:
95   case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
96     return ExpArity::kBinary;
97   }
98   llvm_unreachable("unexpected kind");
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Constructors.
103 //===----------------------------------------------------------------------===//
104 
105 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
106                      Operation *o, Attribute a)
107     : kind(k), val(v), op(o) {
108   switch (kind) {
109   // Leaf.
110   case TensorExp::Kind::kTensor:
111     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
112     tensor = x;
113     return;
114   case TensorExp::Kind::kSynZero:
115     assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
116     return;
117   case TensorExp::Kind::kInvariant:
118     assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
119     return;
120   case TensorExp::Kind::kLoopVar:
121     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
122     loop = x;
123     return;
124   // Unary operations.
125   case TensorExp::Kind::kAbsF:
126   case TensorExp::Kind::kAbsC:
127   case TensorExp::Kind::kAbsI:
128   case TensorExp::Kind::kCeilF:
129   case TensorExp::Kind::kFloorF:
130   case TensorExp::Kind::kSqrtF:
131   case TensorExp::Kind::kSqrtC:
132   case TensorExp::Kind::kExpm1F:
133   case TensorExp::Kind::kExpm1C:
134   case TensorExp::Kind::kLog1pF:
135   case TensorExp::Kind::kLog1pC:
136   case TensorExp::Kind::kSinF:
137   case TensorExp::Kind::kSinC:
138   case TensorExp::Kind::kTanhF:
139   case TensorExp::Kind::kTanhC:
140   case TensorExp::Kind::kNegF:
141   case TensorExp::Kind::kNegC:
142   case TensorExp::Kind::kNegI:
143   case TensorExp::Kind::kCIm:
144   case TensorExp::Kind::kCRe:
145     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
146     children.e0 = x;
147     children.e1 = y;
148     return;
149   case TensorExp::Kind::kTruncF:
150   case TensorExp::Kind::kExtF:
151   case TensorExp::Kind::kCastFS:
152   case TensorExp::Kind::kCastFU:
153   case TensorExp::Kind::kCastSF:
154   case TensorExp::Kind::kCastUF:
155   case TensorExp::Kind::kCastS:
156   case TensorExp::Kind::kCastU:
157   case TensorExp::Kind::kCastIdx:
158   case TensorExp::Kind::kTruncI:
159   case TensorExp::Kind::kBitCast:
160     assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
161     children.e0 = x;
162     children.e1 = y;
163     return;
164   case TensorExp::Kind::kBinaryBranch:
165   case TensorExp::Kind::kSelect:
166     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
167     children.e0 = x;
168     children.e1 = y;
169     return;
170   case TensorExp::Kind::kUnary:
171     // No assertion on y can be made, as the branching paths involve both
172     // a unary (`mapSet`) and binary (`disjSet`) pathway.
173     assert(x != detail::kInvalidId && !v && o);
174     children.e0 = x;
175     children.e1 = y;
176     return;
177   // Binary operations.
178   case TensorExp::Kind::kMulF:
179   case TensorExp::Kind::kMulC:
180   case TensorExp::Kind::kMulI:
181   case TensorExp::Kind::kDivF:
182   case TensorExp::Kind::kDivC:
183   case TensorExp::Kind::kDivS:
184   case TensorExp::Kind::kDivU:
185   case TensorExp::Kind::kAddF:
186   case TensorExp::Kind::kAddC:
187   case TensorExp::Kind::kAddI:
188   case TensorExp::Kind::kSubF:
189   case TensorExp::Kind::kSubC:
190   case TensorExp::Kind::kSubI:
191   case TensorExp::Kind::kAndI:
192   case TensorExp::Kind::kOrI:
193   case TensorExp::Kind::kXorI:
194   case TensorExp::Kind::kShrS:
195   case TensorExp::Kind::kShrU:
196   case TensorExp::Kind::kShlI:
197     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
198     children.e0 = x;
199     children.e1 = y;
200     return;
201   case TensorExp::Kind::kCmpF:
202   case TensorExp::Kind::kCmpI:
203     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
204     attr = a;
205     children.e0 = x;
206     children.e1 = y;
207     return;
208   case TensorExp::Kind::kBinary:
209   case TensorExp::Kind::kReduce:
210     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
211     children.e0 = x;
212     children.e1 = y;
213     return;
214   case TensorExp::Kind::kDenseOp:
215     assert(x != detail::kInvalidId && !v && o);
216     children.e0 = x;
217     children.e1 = y;
218     return;
219   }
220   llvm_unreachable("unexpected kind");
221 }
222 
// Sets up a merger for `numInputOutputTensors` tensors, the last of which
// (id numInputOutputTensors - 1) is the output tensor. One extra "synthetic"
// tensor id (numInputOutputTensors) is reserved as well, which is why
// `numTensors` is one larger than the number of input/output tensors.
//
// All lookup tables are sized up front:
//  - lvlTypes / loopToLvl: numTensors x numLoops, initially
//    LevelFormat::Undef / no level;
//  - lvlToLoop / levelToDependentLoop: numTensors x maxLvlRank, initially
//    no loop / no dependent loops;
//  - loopToUnresolvedLvls: numLoops x numTensors, initially empty;
//  - loopBounds: one (numTensors, numLoops) pair per loop. This looks like an
//    out-of-range "unset" sentinel; NOTE(review): confirm against the header.
// NOTE: later initializers read the *member* `numTensors` (the `numLoops`
// they use is the constructor parameter), so `numTensors` must be declared
// before those tables in the class definition.
Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
241 
242 //===----------------------------------------------------------------------===//
243 // Lattice methods.
244 //===----------------------------------------------------------------------===//
245 
246 ExprId Merger::addTensorExp(TensorId t) {
247   assert(isValidTensorId(t));
248   const ExprId eNew(tensorExps.size());
249   tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
250                           Value(), nullptr, nullptr);
251   return eNew;
252 }
253 
254 ExprId Merger::addLoopVarExp(LoopId i) {
255   assert(isValidLoopId(i));
256   const ExprId eNew(tensorExps.size());
257   tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
258                           Value(), nullptr, nullptr);
259   return eNew;
260 }
261 
262 ExprId Merger::addInvariantExp(Value v) {
263   const ExprId eNew(tensorExps.size());
264   tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
265                           detail::kInvalidId, v, nullptr, nullptr);
266   return eNew;
267 }
268 
269 ExprId Merger::addSynZeroExp() {
270   const ExprId eNew(tensorExps.size());
271   tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
272                           detail::kInvalidId, Value(), nullptr, nullptr);
273   return eNew;
274 }
275 
276 ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
277                       Attribute attr) {
278   assert(k > TensorExp::Kind::kLoopVar);
279   const ExprId eNew(tensorExps.size());
280   tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
281   return eNew;
282 }
283 
284 ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
285                       Attribute attr) {
286   assert(k > TensorExp::Kind::kLoopVar);
287   const ExprId eNew(tensorExps.size());
288   tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
289   return eNew;
290 }
291 
292 LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
293   const LatPointId pNew(latPoints.size());
294   const unsigned size = numLoops * numTensors;
295   const TensorLoopId b = makeTensorLoopId(t, i);
296   latPoints.emplace_back(size, e);
297   latPoints[pNew].bits.set(b);
298   return pNew;
299 }
300 
301 LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
302   assert(bits.size() == numLoops * numTensors);
303   const LatPointId pNew(latPoints.size());
304   latPoints.emplace_back(bits, e);
305   return pNew;
306 }
307 
308 LatSetId Merger::addSet() {
309   const LatSetId sNew(latSets.size());
310   latSets.emplace_back();
311   return sNew;
312 }
313 
314 LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
315                            Operation *op) {
316   TensorExp::Kind kind = exp(e).kind;
317   Attribute attr = exp(e).attr;
318   const LatPointId pNew(latPoints.size());
319   const auto &point0 = lat(p0);
320   const auto &point1 = lat(p1);
321   BitVector bits(point0.bits);
322   bits |= point1.bits;
323   const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
324   latPoints.emplace_back(bits, ne);
325   return pNew;
326 }
327 
328 LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
329   const LatSetId sNew = addSet();
330   auto &setNew = latSets[sNew];
331   for (const LatPointId p0 : set(s0))
332     for (const LatPointId p1 : set(s1))
333       setNew.push_back(conjLat(e, p0, p1, op));
334   return sNew;
335 }
336 
337 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
338   const LatSetId sNew = conjSet(e, s0, s1, op);
339   TensorExp::Kind kind = exp(e).kind;
340 
341   // Followed by all in s0.
342   latSets[sNew].append(latSets[s0]);
343   // Map binary 0-y to unary -y.
344   // TODO: move this if-else logic into buildLattices
345   if (kind == TensorExp::Kind::kSubF)
346     s1 = mapSet(TensorExp::Kind::kNegF, s1);
347   else if (kind == TensorExp::Kind::kSubC)
348     s1 = mapSet(TensorExp::Kind::kNegC, s1);
349   else if (kind == TensorExp::Kind::kSubI)
350     s1 = mapSet(TensorExp::Kind::kNegI, s1);
351   // Followed by all in s1.
352   latSets[sNew].append(latSets[s1]);
353   return sNew;
354 }
355 
356 LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
357   assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358          exp(e).kind == TensorExp::Kind::kCmpF);
359   const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360 
361   ExprId e0 = exp(e).children.e0;
362   ExprId e1 = exp(e).children.e1;
363   if (exp(e0).kind == TensorExp::Kind::kSynZero ||
364       exp(e1).kind == TensorExp::Kind::kSynZero) {
365     // lhs and rhs can't be synthetic zero at the same time.
366     assert(exp(e0).kind != exp(e1).kind);
367     // If one of the operands has already been assigned to zero (the
368     // element is absent in the corresponding operand), then we do not
369     // need to build disjunctive set for it.
370     return sNew;
371   }
372 
373   auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
374   auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
375   latSets[sNew].append(latSets[lhsSet]);
376   latSets[sNew].append(latSets[rhsSet]);
377   return sNew;
378 }
379 
380 LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381                           bool includeLeft, TensorExp::Kind ltrans,
382                           Operation *opleft, bool includeRight,
383                           TensorExp::Kind rtrans, Operation *opright) {
384   const LatSetId sNew = conjSet(e, s0, s1, orig);
385   // Left Region.
386   if (includeLeft) {
387     if (opleft)
388       s0 = mapSet(ltrans, s0, Value(), opleft);
389     latSets[sNew].append(latSets[s0]);
390   }
391   // Right Region.
392   if (includeRight) {
393     if (opright)
394       s1 = mapSet(rtrans, s1, Value(), opright);
395     latSets[sNew].append(latSets[s1]);
396   }
397   return sNew;
398 }
399 
400 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
401                         Operation *op) {
402   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
403          TensorExp::Kind::kDenseOp == kind);
404   const LatSetId sNew = addSet();
405   auto &setNew = latSets[sNew];
406   for (const LatPointId p : set(s0)) {
407     const auto &point = latPoints[p];
408     setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
409   }
410   return sNew;
411 }
412 
413 LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
414   TensorExp::Kind kind = exp(e).kind;
415   Attribute a = exp(e).attr;
416   assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
417   // Must be a binary operation.
418   const LatSetId sNew = addSet();
419   auto &setNew = latSets[sNew];
420   const ExprId zeroExp = addSynZeroExp();
421   for (const LatPointId p : set(s0)) {
422     const auto &point = latPoints[p];
423     ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
424                             : addExp(kind, point.exp, zeroExp, nullptr, a);
425     setNew.push_back(addLat(point.bits, newExp));
426   }
427   return sNew;
428 }
429 
430 LatSetId Merger::optimizeSet(LatSetId s0) {
431   const LatSetId sNew = addSet();
432   auto &setNew = latSets[sNew];
433   const auto &set0 = set(s0);
434   assert(!set0.empty());
435   const LatPointId p0 = set0[0];
436   for (const LatPointId p1 : set0) {
437     bool add = true;
438     if (p0 != p1) {
439       // Check whether this is a straightforward copy.
440       if (expIsTensor(latPoints[p1].exp, outTensor))
441         continue;
442       // Check whether this conjunction is already covered.
443       for (const LatPointId p2 : setNew) {
444         assert(!latGT(p1, p2)); // Lj => Li would be bad
445         if (onlyDenseDiff(p2, p1)) {
446           add = false;
447           break;
448         }
449       }
450       assert(!add || latGT(p0, p1));
451     }
452     if (add)
453       setNew.push_back(p1);
454   }
455   for (const LatPointId p : setNew)
456     latPoints[p].simple = simplifyCond(sNew, p);
457   return sNew;
458 }
459 
// Computes the simplified loop condition of lattice point `p0` within set
// `s0`. It returns a copy of the point's bits from which non-sparse
// conditions are cleared, so that at most one of them survives. A bit counts
// as non-sparse when it is set, is not a sparse level with a non-trivial
// index expression, and its level type has no sparse semantic. Sparse bits
// are never cleared.
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  // While `reset` is true, every further non-sparse bit visited is cleared.
  // For a singleton with any sparse condition it starts out true, so all
  // non-sparse bits get cleared. Otherwise the first one visited is kept.
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Starts resetting from a dense level, so that the first bit (if kept)
    // is not undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits reversely to always
  // keep the rightmost bit (which could possibly be a synthetic tensor).
  // The sweep visits each of the `be` bits exactly once: it starts at bit
  // be - 1 - offset, walks downward, and wraps from 0 back to be - 1.
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // Slice on dense level has `locate` property as well, and can be optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!lt.hasSparseSemantic()) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}
501 
502 bool Merger::latGT(LatPointId i, LatPointId j) const {
503   const BitVector &bitsi = lat(i).bits;
504   const BitVector &bitsj = lat(j).bits;
505   assert(bitsi.size() == bitsj.size());
506   if (bitsi.count() > bitsj.count()) {
507     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
508       if (bitsj[b] && !bitsi[b])
509         return false;
510     return true;
511   }
512   return false;
513 }
514 
515 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
516   BitVector tmp(latPoints[j].bits);
517   tmp ^= latPoints[i].bits;
518   return !hasAnySparse(tmp);
519 }
520 
521 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
522   const auto &expr = exp(e);
523   // First we check `expIsTensor`.
524   if (expr.kind == TensorExp::Kind::kTensor)
525     return expr.tensor == t;
526 
527   switch (getExpArity(expr.kind)) {
528   case ExpArity::kNullary:
529     return false;
530   case ExpArity::kUnary: {
531     const ExprId e0 = expr.children.e0;
532     return expContainsTensor(e0, t);
533   }
534   case ExpArity::kBinary: {
535     const ExprId e0 = expr.children.e0;
536     const ExprId e1 = expr.children.e1;
537     return expContainsTensor(e0, t) || expContainsTensor(e1, t);
538   }
539   }
540   llvm_unreachable("unexpected arity");
541 }
542 
543 bool Merger::hasNegateOnOut(ExprId e) const {
544   const auto &expr = exp(e);
545   switch (expr.kind) {
546   case TensorExp::Kind::kNegF:
547   case TensorExp::Kind::kNegC:
548   case TensorExp::Kind::kNegI:
549     return expContainsTensor(expr.children.e0, outTensor);
550   case TensorExp::Kind::kSubF:
551   case TensorExp::Kind::kSubC:
552   case TensorExp::Kind::kSubI:
553     return expContainsTensor(expr.children.e1, outTensor) ||
554            hasNegateOnOut(expr.children.e0);
555   case TensorExp::Kind::kDenseOp: {
556     bool lhsNeg = hasNegateOnOut(expr.children.e0);
557     if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
558       return hasNegateOnOut(expr.children.e1);
559     return lhsNeg;
560   }
561   default: {
562     switch (getExpArity(expr.kind)) {
563     case ExpArity::kNullary:
564       return false;
565     case ExpArity::kUnary:
566       return hasNegateOnOut(expr.children.e0);
567     case ExpArity::kBinary:
568       return hasNegateOnOut(expr.children.e0) ||
569              hasNegateOnOut(expr.children.e1);
570     }
571   }
572   }
573   llvm_unreachable("unexpected kind");
574 }
575 
576 bool Merger::isSingleCondition(TensorId t, ExprId e) const {
577   assert(isValidTensorId(t));
578   const auto &expr = exp(e);
579   switch (expr.kind) {
580   // Leaf.
581   case TensorExp::Kind::kTensor:
582     return expr.tensor == t;
583   case TensorExp::Kind::kInvariant:
584   case TensorExp::Kind::kLoopVar:
585   case TensorExp::Kind::kSynZero:
586     return false;
587   // Unary operations.
588   case TensorExp::Kind::kAbsF:
589   case TensorExp::Kind::kAbsC:
590   case TensorExp::Kind::kAbsI:
591   case TensorExp::Kind::kCeilF:
592   case TensorExp::Kind::kFloorF:
593   case TensorExp::Kind::kSqrtF:
594   case TensorExp::Kind::kSqrtC:
595   case TensorExp::Kind::kExpm1F:
596   case TensorExp::Kind::kExpm1C:
597   case TensorExp::Kind::kLog1pF:
598   case TensorExp::Kind::kLog1pC:
599   case TensorExp::Kind::kSinF:
600   case TensorExp::Kind::kSinC:
601   case TensorExp::Kind::kTanhF:
602   case TensorExp::Kind::kTanhC:
603   case TensorExp::Kind::kNegF:
604   case TensorExp::Kind::kNegC:
605   case TensorExp::Kind::kNegI:
606   case TensorExp::Kind::kTruncF:
607   case TensorExp::Kind::kExtF:
608   case TensorExp::Kind::kCastFS:
609   case TensorExp::Kind::kCastFU:
610   case TensorExp::Kind::kCastSF:
611   case TensorExp::Kind::kCastUF:
612   case TensorExp::Kind::kCastS:
613   case TensorExp::Kind::kCastU:
614   case TensorExp::Kind::kCastIdx:
615   case TensorExp::Kind::kTruncI:
616   case TensorExp::Kind::kCIm:
617   case TensorExp::Kind::kCRe:
618   case TensorExp::Kind::kBitCast:
619   case TensorExp::Kind::kUnary:
620     return isSingleCondition(t, expr.children.e0);
621   case TensorExp::Kind::kBinaryBranch:
622   case TensorExp::Kind::kSelect:
623     return false;
624   // Binary operations.
625   case TensorExp::Kind::kDivF: // note: x / c only
626   case TensorExp::Kind::kDivC:
627   case TensorExp::Kind::kDivS:
628   case TensorExp::Kind::kDivU:
629     assert(!maybeZero(expr.children.e1));
630     return isSingleCondition(t, expr.children.e0);
631   case TensorExp::Kind::kShrS: // note: x >> inv only
632   case TensorExp::Kind::kShrU:
633   case TensorExp::Kind::kShlI:
634     assert(isInvariant(expr.children.e1));
635     return isSingleCondition(t, expr.children.e0);
636   case TensorExp::Kind::kMulF:
637   case TensorExp::Kind::kMulC:
638   case TensorExp::Kind::kMulI:
639   case TensorExp::Kind::kAndI:
640   case TensorExp::Kind::kReduce:
641     if (isSingleCondition(t, expr.children.e0))
642       return isSingleCondition(t, expr.children.e1) ||
643              isInvariant(expr.children.e1);
644     if (isSingleCondition(t, expr.children.e1))
645       return isInvariant(expr.children.e0);
646     return false;
647   case TensorExp::Kind::kAddF:
648   case TensorExp::Kind::kAddC:
649   case TensorExp::Kind::kAddI:
650     return isSingleCondition(t, expr.children.e0) &&
651            isSingleCondition(t, expr.children.e1);
652   case TensorExp::Kind::kSubF:
653   case TensorExp::Kind::kSubC:
654   case TensorExp::Kind::kSubI:
655   case TensorExp::Kind::kOrI:
656   case TensorExp::Kind::kXorI:
657   case TensorExp::Kind::kCmpF:
658   case TensorExp::Kind::kCmpI:
659   case TensorExp::Kind::kBinary:
660     return false;
661   case TensorExp::Kind::kDenseOp:
662     // Since Merger guarantees all the operands of the kDenseOp to be dense, the
663     // operation must be single-condition.
664     return true;
665   }
666   llvm_unreachable("unexpected kind");
667 }
668 
669 bool Merger::hasAnySparse(const BitVector &bits) const {
670   for (TensorLoopId b : bits.set_bits()) {
671     const auto lt = getLvlType(b);
672     if (lt.hasSparseSemantic())
673       return true;
674   }
675   return hasSparseIdxReduction(bits);
676 }
677 
678 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
679   for (TensorLoopId b : bits.set_bits())
680     if (isSparseLvlWithNonTrivialIdxExp(b))
681       return true;
682   return false;
683 }
684 
685 #ifndef NDEBUG
686 
687 //===----------------------------------------------------------------------===//
688 // Print methods (for debugging).
689 //===----------------------------------------------------------------------===//
690 
691 static const char *kindToOpSymbol(TensorExp::Kind kind) {
692   switch (kind) {
693   // Leaf.
694   case TensorExp::Kind::kTensor:
695     return "tensor";
696   case TensorExp::Kind::kInvariant:
697     return "invariant";
698   case TensorExp::Kind::kLoopVar:
699     return "index";
700   case TensorExp::Kind::kSynZero:
701     return "0";
702   // Unary operations.
703   case TensorExp::Kind::kAbsF:
704   case TensorExp::Kind::kAbsC:
705   case TensorExp::Kind::kAbsI:
706     return "abs";
707   case TensorExp::Kind::kCeilF:
708     return "ceil";
709   case TensorExp::Kind::kFloorF:
710     return "floor";
711   case TensorExp::Kind::kSqrtF:
712   case TensorExp::Kind::kSqrtC:
713     return "sqrt";
714   case TensorExp::Kind::kExpm1F:
715   case TensorExp::Kind::kExpm1C:
716     return "expm1";
717   case TensorExp::Kind::kLog1pF:
718   case TensorExp::Kind::kLog1pC:
719     return "log1p";
720   case TensorExp::Kind::kSinF:
721   case TensorExp::Kind::kSinC:
722     return "sin";
723   case TensorExp::Kind::kTanhF:
724   case TensorExp::Kind::kTanhC:
725     return "tanh";
726   case TensorExp::Kind::kNegF:
727   case TensorExp::Kind::kNegC:
728   case TensorExp::Kind::kNegI:
729     return "-";
730   case TensorExp::Kind::kTruncF:
731   case TensorExp::Kind::kExtF:
732   case TensorExp::Kind::kCastFS:
733   case TensorExp::Kind::kCastFU:
734   case TensorExp::Kind::kCastSF:
735   case TensorExp::Kind::kCastUF:
736   case TensorExp::Kind::kCastS:
737   case TensorExp::Kind::kCastU:
738   case TensorExp::Kind::kCastIdx:
739   case TensorExp::Kind::kTruncI:
740   case TensorExp::Kind::kCIm:
741     return "complex.im";
742   case TensorExp::Kind::kCRe:
743     return "complex.re";
744   case TensorExp::Kind::kBitCast:
745     return "cast";
746   case TensorExp::Kind::kBinaryBranch:
747     return "binary_branch";
748   case TensorExp::Kind::kUnary:
749     return "unary";
750   case TensorExp::Kind::kSelect:
751     return "select";
752   // Binary operations.
753   case TensorExp::Kind::kMulF:
754   case TensorExp::Kind::kMulC:
755   case TensorExp::Kind::kMulI:
756     return "*";
757   case TensorExp::Kind::kDivF:
758   case TensorExp::Kind::kDivC:
759   case TensorExp::Kind::kDivS:
760   case TensorExp::Kind::kDivU:
761     return "/";
762   case TensorExp::Kind::kAddF:
763   case TensorExp::Kind::kAddC:
764   case TensorExp::Kind::kAddI:
765     return "+";
766   case TensorExp::Kind::kSubF:
767   case TensorExp::Kind::kSubC:
768   case TensorExp::Kind::kSubI:
769     return "-";
770   case TensorExp::Kind::kAndI:
771     return "&";
772   case TensorExp::Kind::kOrI:
773     return "|";
774   case TensorExp::Kind::kXorI:
775     return "^";
776   case TensorExp::Kind::kShrS:
777     return "a>>";
778   case TensorExp::Kind::kShrU:
779     return ">>";
780   case TensorExp::Kind::kShlI:
781     return "<<";
782   case TensorExp::Kind::kCmpF:
783   case TensorExp::Kind::kCmpI:
784     return "cmp";
785   case TensorExp::Kind::kBinary:
786     return "binary";
787   case TensorExp::Kind::kReduce:
788     return "reduce";
789   case TensorExp::Kind::kDenseOp:
790     return "dense";
791   }
792   llvm_unreachable("unexpected kind for symbol");
793 }
794 
795 void Merger::dumpExp(ExprId e) const {
796   const auto &expr = exp(e);
797   switch (expr.kind) {
798   // Leaf.
799   case TensorExp::Kind::kTensor:
800     if (expr.tensor == syntheticTensor)
801       llvm::dbgs() << "synthetic_";
802     else if (expr.tensor == outTensor)
803       llvm::dbgs() << "output_";
804     llvm::dbgs() << "tensor_" << expr.tensor;
805     break;
806   case TensorExp::Kind::kInvariant:
807     llvm::dbgs() << "invariant";
808     break;
809   case TensorExp::Kind::kSynZero:
810     llvm::dbgs() << "0";
811     break;
812   case TensorExp::Kind::kLoopVar:
813     llvm::dbgs() << "loopvar_" << expr.loop;
814     break;
815   // Unary operations.
816   case TensorExp::Kind::kAbsF:
817   case TensorExp::Kind::kAbsC:
818   case TensorExp::Kind::kAbsI:
819   case TensorExp::Kind::kCeilF:
820   case TensorExp::Kind::kFloorF:
821   case TensorExp::Kind::kSqrtF:
822   case TensorExp::Kind::kSqrtC:
823   case TensorExp::Kind::kExpm1F:
824   case TensorExp::Kind::kExpm1C:
825   case TensorExp::Kind::kLog1pF:
826   case TensorExp::Kind::kLog1pC:
827   case TensorExp::Kind::kSinF:
828   case TensorExp::Kind::kSinC:
829   case TensorExp::Kind::kTanhF:
830   case TensorExp::Kind::kTanhC:
831   case TensorExp::Kind::kNegF:
832   case TensorExp::Kind::kNegC:
833   case TensorExp::Kind::kNegI:
834   case TensorExp::Kind::kTruncF:
835   case TensorExp::Kind::kExtF:
836   case TensorExp::Kind::kCastFS:
837   case TensorExp::Kind::kCastFU:
838   case TensorExp::Kind::kCastSF:
839   case TensorExp::Kind::kCastUF:
840   case TensorExp::Kind::kCastS:
841   case TensorExp::Kind::kCastU:
842   case TensorExp::Kind::kCastIdx:
843   case TensorExp::Kind::kTruncI:
844   case TensorExp::Kind::kCIm:
845   case TensorExp::Kind::kCRe:
846   case TensorExp::Kind::kBitCast:
847   case TensorExp::Kind::kBinaryBranch:
848   case TensorExp::Kind::kUnary:
849   case TensorExp::Kind::kSelect:
850     llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
851     dumpExp(expr.children.e0);
852     break;
853   // Binary operations.
854   case TensorExp::Kind::kMulF:
855   case TensorExp::Kind::kMulC:
856   case TensorExp::Kind::kMulI:
857   case TensorExp::Kind::kDivF:
858   case TensorExp::Kind::kDivC:
859   case TensorExp::Kind::kDivS:
860   case TensorExp::Kind::kDivU:
861   case TensorExp::Kind::kAddF:
862   case TensorExp::Kind::kAddC:
863   case TensorExp::Kind::kAddI:
864   case TensorExp::Kind::kSubF:
865   case TensorExp::Kind::kSubC:
866   case TensorExp::Kind::kSubI:
867   case TensorExp::Kind::kAndI:
868   case TensorExp::Kind::kOrI:
869   case TensorExp::Kind::kXorI:
870   case TensorExp::Kind::kShrS:
871   case TensorExp::Kind::kShrU:
872   case TensorExp::Kind::kShlI:
873   case TensorExp::Kind::kCmpF:
874   case TensorExp::Kind::kCmpI:
875   case TensorExp::Kind::kBinary:
876   case TensorExp::Kind::kReduce:
877   case TensorExp::Kind::kDenseOp:
878     llvm::dbgs() << "(";
879     dumpExp(expr.children.e0);
880     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
881     if (expr.attr)
882       llvm::dbgs() << "{" << expr.attr << "}";
883     if (expr.children.e1 != detail::kInvalidId) {
884       llvm::dbgs() << " ";
885       dumpExp(expr.children.e1);
886       llvm::dbgs() << ")";
887     } else {
888       assert(expr.kind == TensorExp::Kind::kDenseOp);
889     }
890     break;
891   }
892 }
893 
894 void Merger::dumpLat(LatPointId p) const {
895   const auto &point = lat(p);
896   llvm::dbgs() << "lat(";
897   dumpBits(point.bits);
898   llvm::dbgs() << " :";
899   dumpBits(point.simple);
900   llvm::dbgs() << " : ";
901   dumpExp(point.exp);
902   llvm::dbgs() << " )\n";
903 }
904 
905 void Merger::dumpSet(LatSetId s) const {
906   const auto &ss = set(s);
907   llvm::dbgs() << "{ #" << ss.size() << "\n";
908   for (const LatPointId p : ss) {
909     llvm::dbgs() << "  ";
910     dumpLat(p);
911   }
912   llvm::dbgs() << "}\n";
913 }
914 
915 void Merger::dumpBits(const BitVector &bits) const {
916   for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
917     if (bits[b]) {
918       const TensorId t = tensor(b);
919       const LoopId i = loop(b);
920       const auto lt = lvlTypes[t][i];
921       if (isLvlWithNonTrivialIdxExp(b))
922         llvm::dbgs() << " DEP_" << t << "_" << i;
923       else
924         llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
925     }
926   }
927 }
928 
929 #endif // NDEBUG
930 
931 //===----------------------------------------------------------------------===//
932 // Builder methods.
933 //===----------------------------------------------------------------------===//
934 
/// Builds the lattice set for tensor expression `e` at loop `i` and returns
/// the id of the new set. The lattice sets of the subexpressions are built
/// recursively and then combined according to the operator kind:
/// - zero-preserving unary operators map the operand's set;
/// - multiplicative-like operators take the conjunction;
/// - additive-like operators take the disjunction;
/// - custom semiring operators combine their regions as described per case.
LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      // A sparse output tensor must not drive the iteration space.
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    // A leaf yields a singleton set with one lattice point for (t, i).
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      // Copy fields before recursing (see NOTE at the top).
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      // addInvariantExp adds a new expression, so `expr` must not be used
      // after this point; `e` (an id) remains valid.
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rules applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      // Guaranteed by buildTensorExp, which only builds divisions whose
      // divisor is provably nonzero.
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //   x < y |  !y   |   y   |
    //  -------+-------+-------+
    //     !x  |   0   | 0 < y |
    //      x  | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      // A one-sided branch is only kept when the op provides either an
      // identity for that side or an explicit region to compute it.
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use conjunctive/disjunctive set
    // here, as all the operands of kDenseOp must be dense, the disjunctive set
    // will be optimized into conjunctive set eventually.
    if (expr.children.e1 == detail::kInvalidId) {
      // Single-operand dense op: map the operand's set through the op.
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  // Every kind returns from the switch above.
  llvm_unreachable("unexpected expression kind");
}
1184 }
1185 
1186 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1187   // Build the linalg semantics backward from yield.
1188   Operation *yield = op.getRegion().front().getTerminator();
1189   assert(isa<linalg::YieldOp>(yield));
1190   return buildTensorExp(op, yield->getOperand(0)).first;
1191 }
1192 
1193 /// Only returns false if we are certain this is a nonzero.
1194 bool Merger::maybeZero(ExprId e) const {
1195   const auto &expr = exp(e);
1196   if (expr.kind == TensorExp::Kind::kInvariant) {
1197     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1198       ArrayAttr arrayAttr = c.getValue();
1199       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1200              cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1201     }
1202     if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1203       return c.value() == 0;
1204     if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1205       return c.value().isZero();
1206   }
1207   return true;
1208 }
1209 
1210 Type Merger::inferType(ExprId e, Value src) const {
1211   // Obtain the destination type from the cast node.
1212   Type dtp = exp(e).val.getType();
1213   // Inspect source type. For vector types, apply the same
1214   // vectorization to the destination type.
1215   if (auto vtp = dyn_cast<VectorType>(src.getType()))
1216     return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1217   return dtp;
1218 }
1219 
1220 /// Ensures that the sparsifier can generate code for expression.
1221 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1222   // Arguments are always admissible.
1223   if (isa<BlockArgument>(v))
1224     return true;
1225   // Accept index anywhere.
1226   Operation *def = v.getDefiningOp();
1227   if (isa<linalg::IndexOp>(def))
1228     return true;
1229   // Operation defined outside branch.
1230   if (def->getBlock() != block)
1231     return def->getBlock() != op->getBlock(); // invariant?
1232   // Operation defined within branch. Anything is accepted,
1233   // as long as all subexpressions are admissible.
1234   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1235     if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1236       return false;
1237   return true;
1238 }
1239 
1240 /// Ensures that the sparsifier can generate code for branch.
1241 static bool isAdmissibleBranch(Operation *op, Region &region) {
1242   if (region.empty())
1243     return true;
1244   // Build the semi-ring branch semantics backward from yield.
1245   Operation *yield = region.front().getTerminator();
1246   assert(isa<YieldOp>(yield));
1247   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1248 }
1249 
1250 std::pair<std::optional<ExprId>, bool>
1251 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1252   // Recursion leaves.
1253   if (auto arg = dyn_cast<BlockArgument>(v)) {
1254     const TensorId tid = makeTensorId(arg.getArgNumber());
1255     // Any argument of the generic op that is not marked as a scalar
1256     // argument is considered a tensor, indexed by the implicit loop
1257     // bounds. This includes rank-0 tensor arguments.
1258     if (arg.getOwner()->getParentOp() == op) {
1259       OpOperand &t = op->getOpOperand(tid);
1260       bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1261       if (!op.isScalar(&t))
1262         return {addTensorExp(tid), hasSpDep};
1263       v = t.get(); // get scalar value
1264     }
1265     // Any other argument (marked as scalar argument for the generic op
1266     // or belonging to an enveloping op) is considered invariant.
1267     return {addInvariantExp(v), /*hasSpDep=*/false};
1268   }
1269   // Something defined outside is invariant.
1270   Operation *def = v.getDefiningOp();
1271   if (def->getBlock() != &op.getRegion().front())
1272     return {addInvariantExp(v), /*hasSpDep=*/false};
1273   // Construct index operations.
1274   if (def->getNumOperands() == 0) {
1275     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1276       return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1277   }
1278 
1279   // Construct unary operations if subexpression can be built.
1280   if (def->getNumOperands() == 1) {
1281     const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1282     if (x.has_value()) {
1283       const ExprId e = *x;
1284       if (isa<math::AbsFOp>(def))
1285         return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1286       if (isa<complex::AbsOp>(def))
1287         return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1288       if (isa<math::AbsIOp>(def))
1289         return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1290       if (isa<math::CeilOp>(def))
1291         return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1292       if (isa<math::FloorOp>(def))
1293         return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1294       if (isa<math::SqrtOp>(def))
1295         return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1296       if (isa<complex::SqrtOp>(def))
1297         return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1298       if (isa<math::ExpM1Op>(def))
1299         return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1300       if (isa<complex::Expm1Op>(def))
1301         return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1302       if (isa<math::Log1pOp>(def))
1303         return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1304       if (isa<complex::Log1pOp>(def))
1305         return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1306       if (isa<math::SinOp>(def))
1307         return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1308       if (isa<complex::SinOp>(def))
1309         return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1310       if (isa<math::TanhOp>(def))
1311         return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1312       if (isa<complex::TanhOp>(def))
1313         return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1314       if (isa<arith::NegFOp>(def))
1315         return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1316       if (isa<complex::NegOp>(def))
1317         return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1318       if (isa<arith::TruncFOp>(def))
1319         return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1320       if (isa<arith::ExtFOp>(def))
1321         return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1322       if (isa<arith::FPToSIOp>(def))
1323         return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1324       if (isa<arith::FPToUIOp>(def))
1325         return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1326       if (isa<arith::SIToFPOp>(def))
1327         return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1328       if (isa<arith::UIToFPOp>(def))
1329         return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1330       if (isa<arith::ExtSIOp>(def))
1331         return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1332       if (isa<arith::ExtUIOp>(def))
1333         return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1334       if (isa<arith::IndexCastOp>(def))
1335         return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1336       if (isa<arith::TruncIOp>(def))
1337         return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1338       if (isa<complex::ImOp>(def))
1339         return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1340       if (isa<complex::ReOp>(def))
1341         return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1342       if (isa<arith::BitcastOp>(def))
1343         return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1344       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1345         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1346             isAdmissibleBranch(unop, unop.getAbsentRegion()))
1347           return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1348       }
1349       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1350         if (isAdmissibleBranch(selop, selop.getRegion()))
1351           return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1352       }
1353     }
1354   }
1355   // Construct binary operations if subexpressions can be built.
1356   // See buildLattices() for an explanation of rejecting certain
1357   // division and shift operations.
1358   if (def->getNumOperands() == 2) {
1359     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1360     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1361     bool hasSpDep = xDepSp || yDepSp;
1362     if (x.has_value() && y.has_value()) {
1363       const ExprId e0 = *x;
1364       const ExprId e1 = *y;
1365       if (isa<arith::MulFOp>(def))
1366         return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
1367       if (isa<complex::MulOp>(def))
1368         return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
1369       if (isa<arith::MulIOp>(def))
1370         return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
1371       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1372         return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
1373       if (isa<complex::DivOp>(def) && !maybeZero(e1))
1374         return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
1375       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1376         return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
1377       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1378         return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
1379       if (isa<arith::AddFOp>(def))
1380         return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
1381       if (isa<complex::AddOp>(def))
1382         return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
1383       if (isa<arith::AddIOp>(def))
1384         return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
1385       if (isa<arith::SubFOp>(def))
1386         return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
1387       if (isa<complex::SubOp>(def))
1388         return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
1389       if (isa<arith::SubIOp>(def))
1390         return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
1391       if (isa<arith::AndIOp>(def))
1392         return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
1393       if (isa<arith::OrIOp>(def))
1394         return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
1395       if (isa<arith::XOrIOp>(def))
1396         return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
1397       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1398         return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
1399       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1400         return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
1401       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1402         return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
1403       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1404         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1405             ci.getPredicate() == arith::CmpIPredicate::sle &&
1406             ci.getPredicate() == arith::CmpIPredicate::sge &&
1407             ci.getPredicate() == arith::CmpIPredicate::ule &&
1408             ci.getPredicate() == arith::CmpIPredicate::uge) {
1409           // We can not sparsify comparison with equal, this is because 0 <= 0
1410           // yields true, and thus densifies the result.
1411           return {std::nullopt, false};
1412         }
1413 
1414         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1415                         ci.getPredicateAttr());
1416         return {e, hasSpDep};
1417       }
1418       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1419         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1420             cf.getPredicate() == arith::CmpFPredicate::OGE &&
1421             cf.getPredicate() == arith::CmpFPredicate::OLE &&
1422             cf.getPredicate() == arith::CmpFPredicate::ONE &&
1423             cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1424             cf.getPredicate() == arith::CmpFPredicate::UGE &&
1425             cf.getPredicate() == arith::CmpFPredicate::ULE &&
1426             cf.getPredicate() == arith::CmpFPredicate::ORD &&
1427             cf.getPredicate() == arith::CmpFPredicate::UNO) {
1428           // We can not sparsify comparison with equal, this is because 0 <= 0
1429           // yields true, and thus densifies the result.
1430           return {std::nullopt, false};
1431         }
1432         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1433                         cf.getPredicateAttr());
1434         return {e, hasSpDep};
1435       }
1436       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1437         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1438             (binop.getLeftIdentity() ||
1439              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1440             (binop.getRightIdentity() ||
1441              isAdmissibleBranch(binop, binop.getRightRegion())))
1442           return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
1443       }
1444     }
1445   }
1446   // Construct ternary operations if subexpressions can be built.
1447   if (def->getNumOperands() == 3) {
1448     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1449     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1450     const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1451     bool hasSpDep = xDepSp || yDepSp || zDepSp;
1452     if (x.has_value() && y.has_value() && z.has_value()) {
1453       const ExprId e0 = *x;
1454       const ExprId e1 = *y;
1455       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1456         if (isAdmissibleBranch(redop, redop.getRegion()))
1457           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1458       }
1459     }
1460   }
1461 
1462   // If we reach here, we are dealing with an operation that is not currently
1463   // sparsifiable. We can still generate code for it if all its operands only
1464   // have dense dependencies (i.e., all the values are loaded from dense
1465   // tensors).
1466   if (def->getNumResults() != 1) // only handle single result operation.
1467     return {std::nullopt, false};
1468 
1469   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1470   // Builds all the sub-expressions
1471   for (Value operand : def->getOperands())
1472     subExp.push_back(buildTensorExp(op, operand));
1473 
1474   if (llvm::all_of(subExp,
1475                    [](auto e) { return e.first.has_value() && !e.second; })) {
1476     // All the subexpressions can be built and has *no* sparse dependencies.
1477     if (subExp.size() == 2) {
1478       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1479                       *subExp[1].first, def);
1480       return {e, false};
1481     }
1482     if (subExp.size() == 1) {
1483       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1484                       detail::kInvalidId, def);
1485       return {e, false};
1486     }
1487   }
1488   // Cannot build.
1489   return {std::nullopt, false};
1490 }
1491 
1492 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1493                            ValueRange vals) {
1494   // Make a clone of overlap region.
1495   Region tmpRegion;
1496   IRMapping mapper;
1497   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1498   Block &clonedBlock = tmpRegion.front();
1499   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1500   // Merge cloned block and return yield value.
1501   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1502   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1503   Value val = clonedYield.getResult();
1504   rewriter.eraseOp(clonedYield);
1505   rewriter.eraseOp(placeholder);
1506   return val;
1507 }
1508 
1509 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1510                                Operation *op, Value v0) {
1511   if (!v0)
1512     // Empty input value must be propagated.
1513     return Value();
1514   UnaryOp unop = cast<UnaryOp>(op);
1515   Region &presentRegion = unop.getPresentRegion();
1516   if (presentRegion.empty())
1517     // Uninitialized Value() will be interpreted as missing data in the
1518     // output.
1519     return Value();
1520   return insertYieldOp(rewriter, loc, presentRegion, {v0});
1521 }
1522 
1523 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1524                                 Operation *op, Value v0, Value v1) {
1525   if (!v0 || !v1)
1526     // Empty input values must be propagated.
1527     return Value();
1528   BinaryOp binop = cast<BinaryOp>(op);
1529   Region &overlapRegion = binop.getOverlapRegion();
1530   if (overlapRegion.empty())
1531     // Uninitialized Value() will be interpreted as missing data in the
1532     // output.
1533     return Value();
1534   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1535 }
1536 
1537 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1538                        Value v1) const {
1539   const auto &expr = exp(e);
1540   switch (expr.kind) {
1541   // Leaf.
1542   case TensorExp::Kind::kTensor:
1543   case TensorExp::Kind::kInvariant:
1544   case TensorExp::Kind::kLoopVar:
1545   case TensorExp::Kind::kSynZero:
1546     llvm_unreachable("unexpected non-op");
1547   // Unary operations.
1548   case TensorExp::Kind::kAbsF:
1549     return rewriter.create<math::AbsFOp>(loc, v0);
1550   case TensorExp::Kind::kAbsC: {
1551     auto type = cast<ComplexType>(v0.getType());
1552     auto eltType = cast<FloatType>(type.getElementType());
1553     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1554   }
1555   case TensorExp::Kind::kAbsI:
1556     return rewriter.create<math::AbsIOp>(loc, v0);
1557   case TensorExp::Kind::kCeilF:
1558     return rewriter.create<math::CeilOp>(loc, v0);
1559   case TensorExp::Kind::kFloorF:
1560     return rewriter.create<math::FloorOp>(loc, v0);
1561   case TensorExp::Kind::kSqrtF:
1562     return rewriter.create<math::SqrtOp>(loc, v0);
1563   case TensorExp::Kind::kSqrtC:
1564     return rewriter.create<complex::SqrtOp>(loc, v0);
1565   case TensorExp::Kind::kExpm1F:
1566     return rewriter.create<math::ExpM1Op>(loc, v0);
1567   case TensorExp::Kind::kExpm1C:
1568     return rewriter.create<complex::Expm1Op>(loc, v0);
1569   case TensorExp::Kind::kLog1pF:
1570     return rewriter.create<math::Log1pOp>(loc, v0);
1571   case TensorExp::Kind::kLog1pC:
1572     return rewriter.create<complex::Log1pOp>(loc, v0);
1573   case TensorExp::Kind::kSinF:
1574     return rewriter.create<math::SinOp>(loc, v0);
1575   case TensorExp::Kind::kSinC:
1576     return rewriter.create<complex::SinOp>(loc, v0);
1577   case TensorExp::Kind::kTanhF:
1578     return rewriter.create<math::TanhOp>(loc, v0);
1579   case TensorExp::Kind::kTanhC:
1580     return rewriter.create<complex::TanhOp>(loc, v0);
1581   case TensorExp::Kind::kNegF:
1582     return rewriter.create<arith::NegFOp>(loc, v0);
1583   case TensorExp::Kind::kNegC:
1584     return rewriter.create<complex::NegOp>(loc, v0);
1585   case TensorExp::Kind::kNegI: // no negi in std
1586     return rewriter.create<arith::SubIOp>(
1587         loc,
1588         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1589                                            rewriter.getZeroAttr(v0.getType())),
1590         v0);
1591   case TensorExp::Kind::kTruncF:
1592     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1593   case TensorExp::Kind::kExtF:
1594     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1595   case TensorExp::Kind::kCastFS:
1596     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1597   case TensorExp::Kind::kCastFU:
1598     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1599   case TensorExp::Kind::kCastSF:
1600     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1601   case TensorExp::Kind::kCastUF:
1602     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1603   case TensorExp::Kind::kCastS:
1604     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1605   case TensorExp::Kind::kCastU:
1606     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1607   case TensorExp::Kind::kCastIdx:
1608     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1609   case TensorExp::Kind::kTruncI:
1610     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1611   case TensorExp::Kind::kCIm: {
1612     auto type = cast<ComplexType>(v0.getType());
1613     auto eltType = cast<FloatType>(type.getElementType());
1614     return rewriter.create<complex::ImOp>(loc, eltType, v0);
1615   }
1616   case TensorExp::Kind::kCRe: {
1617     auto type = cast<ComplexType>(v0.getType());
1618     auto eltType = cast<FloatType>(type.getElementType());
1619     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1620   }
1621   case TensorExp::Kind::kBitCast:
1622     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1623   // Binary operations.
1624   case TensorExp::Kind::kMulF:
1625     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1626   case TensorExp::Kind::kMulC:
1627     return rewriter.create<complex::MulOp>(loc, v0, v1);
1628   case TensorExp::Kind::kMulI:
1629     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1630   case TensorExp::Kind::kDivF:
1631     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1632   case TensorExp::Kind::kDivC:
1633     return rewriter.create<complex::DivOp>(loc, v0, v1);
1634   case TensorExp::Kind::kDivS:
1635     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1636   case TensorExp::Kind::kDivU:
1637     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1638   case TensorExp::Kind::kAddF:
1639     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1640   case TensorExp::Kind::kAddC:
1641     return rewriter.create<complex::AddOp>(loc, v0, v1);
1642   case TensorExp::Kind::kAddI:
1643     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1644   case TensorExp::Kind::kSubF:
1645     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1646   case TensorExp::Kind::kSubC:
1647     return rewriter.create<complex::SubOp>(loc, v0, v1);
1648   case TensorExp::Kind::kSubI:
1649     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1650   case TensorExp::Kind::kAndI:
1651     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1652   case TensorExp::Kind::kOrI:
1653     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1654   case TensorExp::Kind::kXorI:
1655     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1656   case TensorExp::Kind::kShrS:
1657     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1658   case TensorExp::Kind::kShrU:
1659     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1660   case TensorExp::Kind::kShlI:
1661     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1662   case TensorExp::Kind::kCmpI: {
1663     auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1664     return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1665   }
1666   case TensorExp::Kind::kCmpF: {
1667     auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1668     return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1669   }
1670   case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1671     return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1672                          {v0});
1673   case TensorExp::Kind::kUnary:
1674     return buildUnaryPresent(rewriter, loc, expr.op, v0);
1675   case TensorExp::Kind::kSelect:
1676     return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1677                          {v0});
1678   case TensorExp::Kind::kBinary:
1679     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1680   case TensorExp::Kind::kReduce: {
1681     ReduceOp redOp = cast<ReduceOp>(expr.op);
1682     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1683   }
1684   case TensorExp::Kind::kDenseOp: {
1685     Operation *actualOp = expr.op;
1686     IRMapping mapping;
1687     mapping.map(actualOp->getOperand(0), v0);
1688     if (actualOp->getNumOperands() == 2)
1689       mapping.map(actualOp->getOperand(1), v1);
1690     return rewriter.clone(*actualOp, mapping)->getResult(0);
1691   }
1692   }
1693   llvm_unreachable("unexpected expression kind in build");
1694 }
1695 
1696 } // namespace sparse_tensor
1697 } // namespace mlir
1698