1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace sparse_tensor {
21 
22 enum class ExpArity {
23   kNullary,
24   kUnary,
25   kBinary,
26 };
27 
28 static ExpArity getExpArity(TensorExp::Kind k) {
29   switch (k) {
30   // Leaf.
31   case TensorExp::Kind::kTensor:
32   case TensorExp::Kind::kInvariant:
33   case TensorExp::Kind::kLoopVar:
34     return ExpArity::kNullary;
35   case TensorExp::Kind::kAbsF:
36   case TensorExp::Kind::kAbsC:
37   case TensorExp::Kind::kAbsI:
38   case TensorExp::Kind::kCeilF:
39   case TensorExp::Kind::kFloorF:
40   case TensorExp::Kind::kSqrtF:
41   case TensorExp::Kind::kSqrtC:
42   case TensorExp::Kind::kExpm1F:
43   case TensorExp::Kind::kExpm1C:
44   case TensorExp::Kind::kLog1pF:
45   case TensorExp::Kind::kLog1pC:
46   case TensorExp::Kind::kSinF:
47   case TensorExp::Kind::kSinC:
48   case TensorExp::Kind::kTanhF:
49   case TensorExp::Kind::kTanhC:
50   case TensorExp::Kind::kTruncF:
51   case TensorExp::Kind::kExtF:
52   case TensorExp::Kind::kCastFS:
53   case TensorExp::Kind::kCastFU:
54   case TensorExp::Kind::kCastSF:
55   case TensorExp::Kind::kCastUF:
56   case TensorExp::Kind::kCastS:
57   case TensorExp::Kind::kCastU:
58   case TensorExp::Kind::kCastIdx:
59   case TensorExp::Kind::kTruncI:
60   case TensorExp::Kind::kCIm:
61   case TensorExp::Kind::kCRe:
62   case TensorExp::Kind::kBitCast:
63   case TensorExp::Kind::kBinaryBranch:
64   case TensorExp::Kind::kUnary:
65   case TensorExp::Kind::kSelect:
66   case TensorExp::Kind::kNegF:
67   case TensorExp::Kind::kNegC:
68   case TensorExp::Kind::kNegI:
69     return ExpArity::kUnary;
70   // Binary operations.
71   case TensorExp::Kind::kDivF:
72   case TensorExp::Kind::kDivC:
73   case TensorExp::Kind::kDivS:
74   case TensorExp::Kind::kDivU:
75   case TensorExp::Kind::kShrS:
76   case TensorExp::Kind::kShrU:
77   case TensorExp::Kind::kShlI:
78   case TensorExp::Kind::kMulF:
79   case TensorExp::Kind::kMulC:
80   case TensorExp::Kind::kMulI:
81   case TensorExp::Kind::kAndI:
82   case TensorExp::Kind::kAddF:
83   case TensorExp::Kind::kAddC:
84   case TensorExp::Kind::kAddI:
85   case TensorExp::Kind::kOrI:
86   case TensorExp::Kind::kXorI:
87   case TensorExp::Kind::kBinary:
88   case TensorExp::Kind::kReduce:
89   case TensorExp::Kind::kSubF:
90   case TensorExp::Kind::kSubC:
91   case TensorExp::Kind::kSubI:
92     return ExpArity::kBinary;
93   }
94   llvm_unreachable("unexpected kind");
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // Constructors.
99 //===----------------------------------------------------------------------===//
100 
101 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
102                      Operation *o)
103     : kind(k), val(v), op(o) {
104   switch (kind) {
105   // Leaf.
106   case TensorExp::Kind::kTensor:
107     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
108     tensor = x;
109     return;
110   case TensorExp::Kind::kInvariant:
111     assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
112     return;
113   case TensorExp::Kind::kLoopVar:
114     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
115     loop = x;
116     return;
117   // Unary operations.
118   case TensorExp::Kind::kAbsF:
119   case TensorExp::Kind::kAbsC:
120   case TensorExp::Kind::kAbsI:
121   case TensorExp::Kind::kCeilF:
122   case TensorExp::Kind::kFloorF:
123   case TensorExp::Kind::kSqrtF:
124   case TensorExp::Kind::kSqrtC:
125   case TensorExp::Kind::kExpm1F:
126   case TensorExp::Kind::kExpm1C:
127   case TensorExp::Kind::kLog1pF:
128   case TensorExp::Kind::kLog1pC:
129   case TensorExp::Kind::kSinF:
130   case TensorExp::Kind::kSinC:
131   case TensorExp::Kind::kTanhF:
132   case TensorExp::Kind::kTanhC:
133   case TensorExp::Kind::kNegF:
134   case TensorExp::Kind::kNegC:
135   case TensorExp::Kind::kNegI:
136   case TensorExp::Kind::kCIm:
137   case TensorExp::Kind::kCRe:
138     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
139     children.e0 = x;
140     children.e1 = y;
141     return;
142   case TensorExp::Kind::kTruncF:
143   case TensorExp::Kind::kExtF:
144   case TensorExp::Kind::kCastFS:
145   case TensorExp::Kind::kCastFU:
146   case TensorExp::Kind::kCastSF:
147   case TensorExp::Kind::kCastUF:
148   case TensorExp::Kind::kCastS:
149   case TensorExp::Kind::kCastU:
150   case TensorExp::Kind::kCastIdx:
151   case TensorExp::Kind::kTruncI:
152   case TensorExp::Kind::kBitCast:
153     assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
154     children.e0 = x;
155     children.e1 = y;
156     return;
157   case TensorExp::Kind::kBinaryBranch:
158   case TensorExp::Kind::kSelect:
159     assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
160     children.e0 = x;
161     children.e1 = y;
162     return;
163   case TensorExp::Kind::kUnary:
164     // No assertion on y can be made, as the branching paths involve both
165     // a unary (`mapSet`) and binary (`disjSet`) pathway.
166     assert(x != detail::kInvalidId && !v && o);
167     children.e0 = x;
168     children.e1 = y;
169     return;
170   // Binary operations.
171   case TensorExp::Kind::kMulF:
172   case TensorExp::Kind::kMulC:
173   case TensorExp::Kind::kMulI:
174   case TensorExp::Kind::kDivF:
175   case TensorExp::Kind::kDivC:
176   case TensorExp::Kind::kDivS:
177   case TensorExp::Kind::kDivU:
178   case TensorExp::Kind::kAddF:
179   case TensorExp::Kind::kAddC:
180   case TensorExp::Kind::kAddI:
181   case TensorExp::Kind::kSubF:
182   case TensorExp::Kind::kSubC:
183   case TensorExp::Kind::kSubI:
184   case TensorExp::Kind::kAndI:
185   case TensorExp::Kind::kOrI:
186   case TensorExp::Kind::kXorI:
187   case TensorExp::Kind::kShrS:
188   case TensorExp::Kind::kShrU:
189   case TensorExp::Kind::kShlI:
190     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
191     children.e0 = x;
192     children.e1 = y;
193     return;
194   case TensorExp::Kind::kBinary:
195   case TensorExp::Kind::kReduce:
196     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
197     children.e0 = x;
198     children.e1 = y;
199     return;
200   }
201   llvm_unreachable("unexpected kind");
202 }
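
// A TensorExp is thus a tagged node: leaves store a tensor id, a loop id, or
// an invariant value; cast-like unary kinds also keep the original result
// value so that inferType() can recover the destination type; and the
// semi-ring kinds (kBinaryBranch, kSelect, kUnary, kBinary, kReduce) keep a
// pointer to the defining operation so its regions can be consulted or cloned
// during code generation.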
203 
204 Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
205                unsigned numFilterLoops, unsigned maxLvlRank)
206     : outTensor(numInputOutputTensors - 1),
207       syntheticTensor(numInputOutputTensors),
208       numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
209       numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
210       lvlTypes(numTensors,
211                std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
212       loopToLvl(numTensors,
213                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
214       lvlToLoop(numTensors,
215                 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
216       loopToDependencies(
217           numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
218                         numTensors, std::nullopt)),
219       levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
220                                            maxLvlRank, std::vector<LoopId>())),
221       loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
222 
223 //===----------------------------------------------------------------------===//
224 // Lattice methods.
225 //===----------------------------------------------------------------------===//
226 
227 ExprId Merger::addTensorExp(TensorId t) {
228   assert(isValidTensorId(t));
229   const ExprId eNew(tensorExps.size());
230   tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
231                           Value(), nullptr);
232   return eNew;
233 }
234 
235 ExprId Merger::addLoopVarExp(LoopId i) {
236   assert(isValidLoopId(i));
237   const ExprId eNew(tensorExps.size());
238   tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
239                           Value(), nullptr);
240   return eNew;
241 }
242 
243 ExprId Merger::addInvariantExp(Value v) {
244   const ExprId eNew(tensorExps.size());
245   tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
246                           detail::kInvalidId, v, nullptr);
247   return eNew;
248 }
249 
250 ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) {
251   assert(k > TensorExp::Kind::kLoopVar);
252   const ExprId eNew(tensorExps.size());
253   tensorExps.emplace_back(k, e0, e1, Value(), op);
254   return eNew;
255 }
256 
257 ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) {
258   assert(k > TensorExp::Kind::kLoopVar);
259   const ExprId eNew(tensorExps.size());
260   tensorExps.emplace_back(k, e, detail::kInvalidId, v, op);
261   return eNew;
262 }
263 
264 LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
265   const LatPointId pNew(latPoints.size());
266   const unsigned size = numLoops * numTensors;
267   const TensorLoopId b = makeTensorLoopId(t, i);
268   latPoints.emplace_back(size, e);
269   latPoints[pNew].bits.set(b);
270   return pNew;
271 }
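
// For example, with numTensors == 3 and numLoops == 2, a lattice point is a
// 6-bit vector; addLat(t, i, e) sets only the bit makeTensorLoopId(t, i),
// recording that tensor t contributes a condition to loop i for expression e.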
272 
273 LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
274   assert(bits.size() == numLoops * numTensors);
275   const LatPointId pNew(latPoints.size());
276   latPoints.emplace_back(bits, e);
277   return pNew;
278 }
279 
280 LatSetId Merger::addSet() {
281   const LatSetId sNew(latSets.size());
282   latSets.emplace_back();
283   return sNew;
284 }
285 
286 LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
287                            Operation *op) {
288   const LatPointId pNew(latPoints.size());
289   const auto &point0 = lat(p0);
290   const auto &point1 = lat(p1);
291   BitVector bits(point0.bits);
292   bits |= point1.bits;
293   const ExprId e = addExp(kind, point0.exp, point1.exp, op);
294   latPoints.emplace_back(bits, e);
295   return pNew;
296 }
297 
298 LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
299                          Operation *op) {
300   const LatSetId sNew = addSet();
301   auto &setNew = latSets[sNew];
302   for (const LatPointId p0 : set(s0))
303     for (const LatPointId p1 : set(s1))
304       setNew.push_back(conjLat(kind, p0, p1, op));
305   return sNew;
306 }
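
// For a conjunction such as x(i) = a(i) * b(i), each operand contributes a
// singleton lattice set, and conjSet() yields the single point { a /\ b }
// whose expression is the product.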
307 
308 LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
309                          Operation *op) {
310   const LatSetId sNew = conjSet(kind, s0, s1, op);
311   // Followed by all in s0.
312   latSets[sNew].append(latSets[s0]);
313   // Map binary 0-y to unary -y.
314   // TODO: move this if-else logic into buildLattices
315   if (kind == TensorExp::Kind::kSubF)
316     s1 = mapSet(TensorExp::Kind::kNegF, s1);
317   else if (kind == TensorExp::Kind::kSubC)
318     s1 = mapSet(TensorExp::Kind::kNegC, s1);
319   else if (kind == TensorExp::Kind::kSubI)
320     s1 = mapSet(TensorExp::Kind::kNegI, s1);
321   // Followed by all in s1.
322   latSets[sNew].append(latSets[s1]);
323   return sNew;
324 }
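
// For a disjunction such as x(i) = a(i) + b(i), the result contains the
// conjunction first, followed by the points of each operand alone, i.e.
// { a /\ b, a, b }; for subtraction the b-only points are first mapped to
// the unary negation -b.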
325 
326 LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
327                           Operation *orig, bool includeLeft,
328                           TensorExp::Kind ltrans, Operation *opleft,
329                           bool includeRight, TensorExp::Kind rtrans,
330                           Operation *opright) {
331   const LatSetId sNew = conjSet(kind, s0, s1, orig);
332   // Left Region.
333   if (includeLeft) {
334     if (opleft)
335       s0 = mapSet(ltrans, s0, Value(), opleft);
336     latSets[sNew].append(latSets[s0]);
337   }
338   // Right Region.
339   if (includeRight) {
340     if (opright)
341       s1 = mapSet(rtrans, s1, Value(), opright);
342     latSets[sNew].append(latSets[s1]);
343   }
344   return sNew;
345 }
346 
347 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
348                         Operation *op) {
349   assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
350   const LatSetId sNew = addSet();
351   auto &setNew = latSets[sNew];
352   for (const LatPointId p : set(s0)) {
353     const auto &point = latPoints[p];
354     setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
355   }
356   return sNew;
357 }
358 
359 LatSetId Merger::optimizeSet(LatSetId s0) {
360   const LatSetId sNew = addSet();
361   auto &setNew = latSets[sNew];
362   const auto &set0 = set(s0);
363   assert(!set0.empty());
364   const LatPointId p0 = set0[0];
365   for (const LatPointId p1 : set0) {
366     bool add = true;
367     if (p0 != p1) {
368       // Check whether this is a straightforward copy.
369       if (expIsTensor(latPoints[p1].exp, outTensor))
370         continue;
371       // Check whether this conjunction is already covered.
372       for (const LatPointId p2 : setNew) {
373         assert(!latGT(p1, p2)); // Lj => Li would be bad
374         if (onlyDenseDiff(p2, p1)) {
375           add = false;
376           break;
377         }
378       }
379       assert(!add || latGT(p0, p1));
380     }
381     if (add)
382       setNew.push_back(p1);
383   }
384   for (const LatPointId p : setNew)
385     latPoints[p].simple = simplifyCond(sNew, p);
386   return sNew;
387 }
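
// For example, for x(i) = a(i) + b(i) with a dense and b sparse, the full
// lattice is { a /\ b, a, b }. Because a is always present, the b-only point
// can never materialize (it differs from a /\ b in dense bits only), so the
// optimized set is { a /\ b, a }.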
388 
389 BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
390   // First determine if this lattice point is a *singleton*, i.e.,
391   // the last point in a lattice such that no other point is less than it.
392   bool isSingleton = true;
393   for (const LatPointId p1 : set(s0)) {
394     if (p0 != p1 && latGT(p0, p1)) {
395       isSingleton = false;
396       break;
397     }
398   }
399 
400   BitVector simple(latPoints[p0].bits);
401   bool reset = isSingleton && hasAnySparse(simple);
402   const TensorLoopId be = simple.size();
403   TensorLoopId offset = 0; // relative to the end
404   if (!reset)
405     // Start resetting from a dense level, so that the first bit (if kept)
406     // does not have an undefined level type.
407     for (unsigned b = 0; b < be; b++) {
408       if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) {
409         offset = be - b - 1; // relative to the end
410         break;
411       }
412     }
413 
414   // Now apply the two basic rules. We also iterate over the bits in reverse
415   // so the rightmost bit (which could be a synthetic tensor) is always kept.
416   for (unsigned b = be - 1 - offset, i = 0; i < be;
417        b = b == 0 ? be - 1 : b - 1, i++) {
418     // A slice on a dense level has the `locate` property as well and can be optimized.
419     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
420       const auto dlt = getLvlType(b);
421       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) && !isCompressedWithHiDLT(dlt)) {
422         if (reset)
423           simple.reset(b);
424         reset = true;
425       }
426     }
427   }
428   return simple;
429 }
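
// For example, for x(i) = a(i) * b(i) with a dense and b sparse, the single
// point a /\ b is a singleton with a sparse condition, so the dense a bit is
// cleared from the simplified condition: the loop condition only tests the
// stored entries of b, while a is accessed directly.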
430 
431 bool Merger::latGT(LatPointId i, LatPointId j) const {
432   const BitVector &bitsi = lat(i).bits;
433   const BitVector &bitsj = lat(j).bits;
434   assert(bitsi.size() == bitsj.size());
435   if (bitsi.count() > bitsj.count()) {
436     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
437       if (bitsj[b] && !bitsi[b])
438         return false;
439     return true;
440   }
441   return false;
442 }
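
// In other words, latGT(i, j) holds iff the condition bits of j form a proper
// subset of the condition bits of i.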
443 
444 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
445   BitVector tmp(latPoints[j].bits);
446   tmp ^= latPoints[i].bits;
447   return !hasAnySparse(tmp);
448 }
449 
450 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
451   const auto &expr = exp(e);
452   // First we check `expIsTensor`.
453   if (expr.kind == TensorExp::Kind::kTensor)
454     return expr.tensor == t;
455 
456   switch (getExpArity(expr.kind)) {
457   case ExpArity::kNullary:
458     return false;
459   case ExpArity::kUnary: {
460     const ExprId e0 = expr.children.e0;
461     return expContainsTensor(e0, t);
462   }
463   case ExpArity::kBinary: {
464     const ExprId e0 = expr.children.e0;
465     const ExprId e1 = expr.children.e1;
466     return expContainsTensor(e0, t) || expContainsTensor(e1, t);
467   }
468   }
469   llvm_unreachable("unexpected arity");
470 }
471 
472 bool Merger::hasNegateOnOut(ExprId e) const {
473   const auto &expr = exp(e);
474   switch (expr.kind) {
475   case TensorExp::Kind::kNegF:
476   case TensorExp::Kind::kNegC:
477   case TensorExp::Kind::kNegI:
478     return expContainsTensor(expr.children.e0, outTensor);
479   case TensorExp::Kind::kSubF:
480   case TensorExp::Kind::kSubC:
481   case TensorExp::Kind::kSubI:
482     return expContainsTensor(expr.children.e1, outTensor) ||
483            hasNegateOnOut(expr.children.e0);
484   default: {
485     switch (getExpArity(expr.kind)) {
486     case ExpArity::kNullary:
487       return false;
488     case ExpArity::kUnary:
489       return hasNegateOnOut(expr.children.e0);
490     case ExpArity::kBinary:
491       return hasNegateOnOut(expr.children.e0) ||
492              hasNegateOnOut(expr.children.e1);
493     }
494   }
495   }
496   llvm_unreachable("unexpected kind");
497 }
498 
499 bool Merger::isSingleCondition(TensorId t, ExprId e) const {
500   assert(isValidTensorId(t));
501   const auto &expr = exp(e);
502   switch (expr.kind) {
503   // Leaf.
504   case TensorExp::Kind::kTensor:
505     return expr.tensor == t;
506   case TensorExp::Kind::kInvariant:
507   case TensorExp::Kind::kLoopVar:
508     return false;
509   // Unary operations.
510   case TensorExp::Kind::kAbsF:
511   case TensorExp::Kind::kAbsC:
512   case TensorExp::Kind::kAbsI:
513   case TensorExp::Kind::kCeilF:
514   case TensorExp::Kind::kFloorF:
515   case TensorExp::Kind::kSqrtF:
516   case TensorExp::Kind::kSqrtC:
517   case TensorExp::Kind::kExpm1F:
518   case TensorExp::Kind::kExpm1C:
519   case TensorExp::Kind::kLog1pF:
520   case TensorExp::Kind::kLog1pC:
521   case TensorExp::Kind::kSinF:
522   case TensorExp::Kind::kSinC:
523   case TensorExp::Kind::kTanhF:
524   case TensorExp::Kind::kTanhC:
525   case TensorExp::Kind::kNegF:
526   case TensorExp::Kind::kNegC:
527   case TensorExp::Kind::kNegI:
528   case TensorExp::Kind::kTruncF:
529   case TensorExp::Kind::kExtF:
530   case TensorExp::Kind::kCastFS:
531   case TensorExp::Kind::kCastFU:
532   case TensorExp::Kind::kCastSF:
533   case TensorExp::Kind::kCastUF:
534   case TensorExp::Kind::kCastS:
535   case TensorExp::Kind::kCastU:
536   case TensorExp::Kind::kCastIdx:
537   case TensorExp::Kind::kTruncI:
538   case TensorExp::Kind::kCIm:
539   case TensorExp::Kind::kCRe:
540   case TensorExp::Kind::kBitCast:
541   case TensorExp::Kind::kUnary:
542     return isSingleCondition(t, expr.children.e0);
543   case TensorExp::Kind::kBinaryBranch:
544   case TensorExp::Kind::kSelect:
545     return false;
546   // Binary operations.
547   case TensorExp::Kind::kDivF: // note: x / c only
548   case TensorExp::Kind::kDivC:
549   case TensorExp::Kind::kDivS:
550   case TensorExp::Kind::kDivU:
551     assert(!maybeZero(expr.children.e1));
552     return isSingleCondition(t, expr.children.e0);
553   case TensorExp::Kind::kShrS: // note: x >> inv only
554   case TensorExp::Kind::kShrU:
555   case TensorExp::Kind::kShlI:
556     assert(isInvariant(expr.children.e1));
557     return isSingleCondition(t, expr.children.e0);
558   case TensorExp::Kind::kMulF:
559   case TensorExp::Kind::kMulC:
560   case TensorExp::Kind::kMulI:
561   case TensorExp::Kind::kAndI:
562   case TensorExp::Kind::kReduce:
563     if (isSingleCondition(t, expr.children.e0))
564       return isSingleCondition(t, expr.children.e1) ||
565              isInvariant(expr.children.e1);
566     if (isSingleCondition(t, expr.children.e1))
567       return isInvariant(expr.children.e0);
568     return false;
569   case TensorExp::Kind::kAddF:
570   case TensorExp::Kind::kAddC:
571   case TensorExp::Kind::kAddI:
572     return isSingleCondition(t, expr.children.e0) &&
573            isSingleCondition(t, expr.children.e1);
574   case TensorExp::Kind::kSubF:
575   case TensorExp::Kind::kSubC:
576   case TensorExp::Kind::kSubI:
577   case TensorExp::Kind::kOrI:
578   case TensorExp::Kind::kXorI:
579   case TensorExp::Kind::kBinary:
580     return false;
581   }
582   llvm_unreachable("unexpected kind");
583 }
584 
585 bool Merger::hasAnySparse(const BitVector &bits) const {
586   for (TensorLoopId b : bits.set_bits()) {
587     const auto dlt = getLvlType(b);
588     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt))
589       return true;
590   }
591   return hasSparseIdxReduction(bits);
592 }
593 
594 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
595   for (TensorLoopId b : bits.set_bits())
596     if (isSparseLvlWithNonTrivialIdxExp(b))
597       return true;
598   return false;
599 }
600 
601 #ifndef NDEBUG
602 
603 //===----------------------------------------------------------------------===//
604 // Print methods (for debugging).
605 //===----------------------------------------------------------------------===//
606 
607 static const char *kindToOpSymbol(TensorExp::Kind kind) {
608   switch (kind) {
609   // Leaf.
610   case TensorExp::Kind::kTensor:
611     return "tensor";
612   case TensorExp::Kind::kInvariant:
613     return "invariant";
614   case TensorExp::Kind::kLoopVar:
615     return "index";
616   // Unary operations.
617   case TensorExp::Kind::kAbsF:
618   case TensorExp::Kind::kAbsC:
619   case TensorExp::Kind::kAbsI:
620     return "abs";
621   case TensorExp::Kind::kCeilF:
622     return "ceil";
623   case TensorExp::Kind::kFloorF:
624     return "floor";
625   case TensorExp::Kind::kSqrtF:
626   case TensorExp::Kind::kSqrtC:
627     return "sqrt";
628   case TensorExp::Kind::kExpm1F:
629   case TensorExp::Kind::kExpm1C:
630     return "expm1";
631   case TensorExp::Kind::kLog1pF:
632   case TensorExp::Kind::kLog1pC:
633     return "log1p";
634   case TensorExp::Kind::kSinF:
635   case TensorExp::Kind::kSinC:
636     return "sin";
637   case TensorExp::Kind::kTanhF:
638   case TensorExp::Kind::kTanhC:
639     return "tanh";
640   case TensorExp::Kind::kNegF:
641   case TensorExp::Kind::kNegC:
642   case TensorExp::Kind::kNegI:
643     return "-";
644   case TensorExp::Kind::kTruncF:
645   case TensorExp::Kind::kExtF:
646   case TensorExp::Kind::kCastFS:
647   case TensorExp::Kind::kCastFU:
648   case TensorExp::Kind::kCastSF:
649   case TensorExp::Kind::kCastUF:
650   case TensorExp::Kind::kCastS:
651   case TensorExp::Kind::kCastU:
652   case TensorExp::Kind::kCastIdx:
653   case TensorExp::Kind::kTruncI:
654   case TensorExp::Kind::kCIm:
655     return "complex.im";
656   case TensorExp::Kind::kCRe:
657     return "complex.re";
658   case TensorExp::Kind::kBitCast:
659     return "cast";
660   case TensorExp::Kind::kBinaryBranch:
661     return "binary_branch";
662   case TensorExp::Kind::kUnary:
663     return "unary";
664   case TensorExp::Kind::kSelect:
665     return "select";
666   // Binary operations.
667   case TensorExp::Kind::kMulF:
668   case TensorExp::Kind::kMulC:
669   case TensorExp::Kind::kMulI:
670     return "*";
671   case TensorExp::Kind::kDivF:
672   case TensorExp::Kind::kDivC:
673   case TensorExp::Kind::kDivS:
674   case TensorExp::Kind::kDivU:
675     return "/";
676   case TensorExp::Kind::kAddF:
677   case TensorExp::Kind::kAddC:
678   case TensorExp::Kind::kAddI:
679     return "+";
680   case TensorExp::Kind::kSubF:
681   case TensorExp::Kind::kSubC:
682   case TensorExp::Kind::kSubI:
683     return "-";
684   case TensorExp::Kind::kAndI:
685     return "&";
686   case TensorExp::Kind::kOrI:
687     return "|";
688   case TensorExp::Kind::kXorI:
689     return "^";
690   case TensorExp::Kind::kShrS:
691     return "a>>";
692   case TensorExp::Kind::kShrU:
693     return ">>";
694   case TensorExp::Kind::kShlI:
695     return "<<";
696   case TensorExp::Kind::kBinary:
697     return "binary";
698   case TensorExp::Kind::kReduce:
699     return "reduce";
700   }
701   llvm_unreachable("unexpected kind for symbol");
702 }
703 
704 void Merger::dumpExp(ExprId e) const {
705   const auto &expr = exp(e);
706   switch (expr.kind) {
707   // Leaf.
708   case TensorExp::Kind::kTensor:
709     if (expr.tensor == syntheticTensor)
710       llvm::dbgs() << "synthetic_";
711     else if (expr.tensor == outTensor)
712       llvm::dbgs() << "output_";
713     llvm::dbgs() << "tensor_" << expr.tensor;
714     break;
715   case TensorExp::Kind::kInvariant:
716     llvm::dbgs() << "invariant";
717     break;
718   case TensorExp::Kind::kLoopVar:
719     llvm::dbgs() << "loopvar_" << expr.loop;
720     break;
721   // Unary operations.
722   case TensorExp::Kind::kAbsF:
723   case TensorExp::Kind::kAbsC:
724   case TensorExp::Kind::kAbsI:
725   case TensorExp::Kind::kCeilF:
726   case TensorExp::Kind::kFloorF:
727   case TensorExp::Kind::kSqrtF:
728   case TensorExp::Kind::kSqrtC:
729   case TensorExp::Kind::kExpm1F:
730   case TensorExp::Kind::kExpm1C:
731   case TensorExp::Kind::kLog1pF:
732   case TensorExp::Kind::kLog1pC:
733   case TensorExp::Kind::kSinF:
734   case TensorExp::Kind::kSinC:
735   case TensorExp::Kind::kTanhF:
736   case TensorExp::Kind::kTanhC:
737   case TensorExp::Kind::kNegF:
738   case TensorExp::Kind::kNegC:
739   case TensorExp::Kind::kNegI:
740   case TensorExp::Kind::kTruncF:
741   case TensorExp::Kind::kExtF:
742   case TensorExp::Kind::kCastFS:
743   case TensorExp::Kind::kCastFU:
744   case TensorExp::Kind::kCastSF:
745   case TensorExp::Kind::kCastUF:
746   case TensorExp::Kind::kCastS:
747   case TensorExp::Kind::kCastU:
748   case TensorExp::Kind::kCastIdx:
749   case TensorExp::Kind::kTruncI:
750   case TensorExp::Kind::kCIm:
751   case TensorExp::Kind::kCRe:
752   case TensorExp::Kind::kBitCast:
753   case TensorExp::Kind::kBinaryBranch:
754   case TensorExp::Kind::kUnary:
755   case TensorExp::Kind::kSelect:
756     llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
757     dumpExp(expr.children.e0);
758     break;
759   // Binary operations.
760   case TensorExp::Kind::kMulF:
761   case TensorExp::Kind::kMulC:
762   case TensorExp::Kind::kMulI:
763   case TensorExp::Kind::kDivF:
764   case TensorExp::Kind::kDivC:
765   case TensorExp::Kind::kDivS:
766   case TensorExp::Kind::kDivU:
767   case TensorExp::Kind::kAddF:
768   case TensorExp::Kind::kAddC:
769   case TensorExp::Kind::kAddI:
770   case TensorExp::Kind::kSubF:
771   case TensorExp::Kind::kSubC:
772   case TensorExp::Kind::kSubI:
773   case TensorExp::Kind::kAndI:
774   case TensorExp::Kind::kOrI:
775   case TensorExp::Kind::kXorI:
776   case TensorExp::Kind::kShrS:
777   case TensorExp::Kind::kShrU:
778   case TensorExp::Kind::kShlI:
779   case TensorExp::Kind::kBinary:
780   case TensorExp::Kind::kReduce:
781     llvm::dbgs() << "(";
782     dumpExp(expr.children.e0);
783     llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " ";
784     dumpExp(expr.children.e1);
785     llvm::dbgs() << ")";
786     break;
787   }
788 }
789 
790 void Merger::dumpLat(LatPointId p) const {
791   const auto &point = lat(p);
792   llvm::dbgs() << "lat(";
793   dumpBits(point.bits);
794   llvm::dbgs() << " :";
795   dumpBits(point.simple);
796   llvm::dbgs() << " : ";
797   dumpExp(point.exp);
798   llvm::dbgs() << " )\n";
799 }
800 
801 void Merger::dumpSet(LatSetId s) const {
802   const auto &ss = set(s);
803   llvm::dbgs() << "{ #" << ss.size() << "\n";
804   for (const LatPointId p : ss) {
805     llvm::dbgs() << "  ";
806     dumpLat(p);
807   }
808   llvm::dbgs() << "}\n";
809 }
810 
811 void Merger::dumpBits(const BitVector &bits) const {
812   for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
813     if (bits[b]) {
814       const TensorId t = tensor(b);
815       const LoopId i = loop(b);
816       const auto dlt = lvlTypes[t][i];
817       if (isLvlWithNonTrivialIdxExp(b))
818         llvm::dbgs() << " DEP_" << t << "_" << i;
819       else
820         llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
821     }
822   }
823 }
824 
825 #endif // NDEBUG
826 
827 //===----------------------------------------------------------------------===//
828 // Builder methods.
829 //===----------------------------------------------------------------------===//
830 
831 LatSetId Merger::buildLattices(ExprId e, LoopId i) {
832   // NOTE: The `expr` reference will be invalidated by recursive calls
833   // (and any other method that may add new expressions); therefore, the
834   // code below must make sure to copy fields of `expr` into local variables
835   // before making any recursive calls.
836   const auto &expr = exp(e);
837   const TensorExp::Kind kind = expr.kind;
838   switch (kind) {
839   // Leaf.
840   case TensorExp::Kind::kTensor:
841   case TensorExp::Kind::kInvariant:
842   case TensorExp::Kind::kLoopVar: {
843     // Either the loop-var is really used in the tensor expression, or it is
844     // set to the undefined loop-var in that level. An invariant expression,
845     // a proper index value, and a truly dynamic sparse output tensor are set
846     // to a synthetic tensor with undefined indices only to ensure the
847     // iteration space is not skipped as a result of their contents.
848     const LatSetId s = addSet();
849     TensorId t = syntheticTensor;
850     if (kind == TensorExp::Kind::kTensor) {
851       t = expr.tensor;
852       if (hasSparseOut && t == outTensor)
853         t = syntheticTensor;
854     }
855     latSets[s].push_back(addLat(t, i, e));
856     return s;
857   }
858   // Unary operations.
859   case TensorExp::Kind::kAbsF:
860   case TensorExp::Kind::kAbsC:
861   case TensorExp::Kind::kAbsI:
862   case TensorExp::Kind::kCeilF:
863   case TensorExp::Kind::kFloorF:
864   case TensorExp::Kind::kSqrtF:
865   case TensorExp::Kind::kSqrtC:
866   case TensorExp::Kind::kExpm1F:
867   case TensorExp::Kind::kExpm1C:
868   case TensorExp::Kind::kLog1pF:
869   case TensorExp::Kind::kLog1pC:
870   case TensorExp::Kind::kSinF:
871   case TensorExp::Kind::kSinC:
872   case TensorExp::Kind::kTanhF:
873   case TensorExp::Kind::kTanhC:
874   case TensorExp::Kind::kNegF:
875   case TensorExp::Kind::kNegC:
876   case TensorExp::Kind::kNegI:
877   case TensorExp::Kind::kTruncF:
878   case TensorExp::Kind::kExtF:
879   case TensorExp::Kind::kCastFS:
880   case TensorExp::Kind::kCastFU:
881   case TensorExp::Kind::kCastSF:
882   case TensorExp::Kind::kCastUF:
883   case TensorExp::Kind::kCastS:
884   case TensorExp::Kind::kCastU:
885   case TensorExp::Kind::kCastIdx:
886   case TensorExp::Kind::kTruncI:
887   case TensorExp::Kind::kCIm:
888   case TensorExp::Kind::kCRe:
889   case TensorExp::Kind::kBitCast:
890     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
891     // lattice set of the operand through the operator into a new set.
892     //
893     //  -y|!y | y |
894     //  --+---+---+
895     //    | 0 |-y |
896     {
897       const ExprId e0 = expr.children.e0;
898       const Value v = expr.val;
899       return mapSet(kind, buildLattices(e0, i), v);
900     }
901   case TensorExp::Kind::kBinaryBranch:
902   case TensorExp::Kind::kSelect:
903     // The left or right half of a binary operation which has already
904     // been split into separate operations for each region.
905     {
906       const ExprId e0 = expr.children.e0;
907       Operation *const op = expr.op;
908       return mapSet(kind, buildLattices(e0, i), Value(), op);
909     }
910   case TensorExp::Kind::kUnary:
911     // A custom unary operation.
912     //
913     //  op y|    !y    |     y      |
914     //  ----+----------+------------+
915     //      | absent() | present(y) |
916     {
917       const ExprId e0 = expr.children.e0;
918       UnaryOp unop = cast<UnaryOp>(expr.op);
919       const LatSetId child0 = buildLattices(e0, i);
920       Region &absentRegion = unop.getAbsentRegion();
921       if (absentRegion.empty()) {
922         // Simple mapping over existing values.
923         return mapSet(kind, child0, Value(), unop);
924       }
925       // Use a disjunction with `unop` on the left and the absent value as an
926       // invariant on the right.
927       Block &absentBlock = absentRegion.front();
928       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
929       const Value absentVal = absentYield.getResult();
930       const ExprId rhs = addInvariantExp(absentVal);
931       return disjSet(kind, child0, buildLattices(rhs, i), unop);
932     }
933   // Binary operations.
934   case TensorExp::Kind::kMulF:
935   case TensorExp::Kind::kMulC:
936   case TensorExp::Kind::kMulI:
937   case TensorExp::Kind::kAndI:
938     // A multiplicative operation only needs to be performed
939     // for the conjunction of sparse iteration spaces.
940     //
941     //  x*y|!y | y |
942     //  ---+---+---+
943     //  !x | 0 | 0 |
944     //   x | 0 |x*y|
945     //
946     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
947     {
948       const ExprId e0 = expr.children.e0;
949       const ExprId e1 = expr.children.e1;
950       return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
951     }
952   case TensorExp::Kind::kDivF:
953   case TensorExp::Kind::kDivC:
954   case TensorExp::Kind::kDivS:
955   case TensorExp::Kind::kDivU:
956     // A division is tricky, since 0/0, 0/c, c/0 all have
957     // specific outcomes for floating-point and integers.
958     // Thus, we need to traverse the full iteration space.
959     //
960     //  x/y|!y | y |
961     //  ---+---+---+
962     //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
963     //   x |x/0|x/y|  INT: x/0=exception for any x
964     //
965     // TODO: for now we "fixed" this by only accepting x/c cases
966     //       during expression building, so that the conjunction
967     //       rule applies (viz. x/c = x*(1/c) as far as lattice
968     //       construction is concerned).
969     {
970       const ExprId e0 = expr.children.e0;
971       const ExprId e1 = expr.children.e1;
972       assert(!maybeZero(e1));
973       return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
974     }
975   case TensorExp::Kind::kAddF:
976   case TensorExp::Kind::kAddC:
977   case TensorExp::Kind::kAddI:
978   case TensorExp::Kind::kSubF:
979   case TensorExp::Kind::kSubC:
980   case TensorExp::Kind::kSubI:
981   case TensorExp::Kind::kOrI:
982   case TensorExp::Kind::kXorI:
983     // An additive operation needs to be performed
984     // for the disjunction of sparse iteration spaces.
985     //
986     //  x+y|!y | y |    x-y|!y | y |
987     //  ---+---+---+    ---+---+---+
988     //  !x | 0 | y |    !x | 0 |-y |
989     //   x | x |x+y|     x | x |x-y|
990     {
991       const ExprId e0 = expr.children.e0;
992       const ExprId e1 = expr.children.e1;
993       return disjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
994     }
995   case TensorExp::Kind::kShrS:
996   case TensorExp::Kind::kShrU:
997   case TensorExp::Kind::kShlI:
998     // A shift operation by an invariant amount (viz. tensor expressions
999     // can only occur on the left-hand side of the operator) can be handled
1000     // with the conjunction rule.
1001     {
1002       const ExprId e0 = expr.children.e0;
1003       const ExprId e1 = expr.children.e1;
1004       assert(isInvariant(e1));
1005       return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
1006     }
1007   case TensorExp::Kind::kBinary:
1008     // A custom binary operation.
1009     //
1010     //  x op y|   !y    |       y      |
1011     //  ------+---------+--------------+
1012     //    !x  |  empty  |   right(y)   |
1013     //     x  | left(x) | overlap(x,y) |
1014     {
1015       const ExprId e0 = expr.children.e0;
1016       const ExprId e1 = expr.children.e1;
1017       BinaryOp binop = cast<BinaryOp>(expr.op);
1018       const LatSetId child0 = buildLattices(e0, i);
1019       const LatSetId child1 = buildLattices(e1, i);
1020       Region &leftRegion = binop.getLeftRegion();
1021       Region &rightRegion = binop.getRightRegion();
1022       // Left Region.
1023       Operation *leftYield = nullptr;
1024       if (!leftRegion.empty()) {
1025         Block &leftBlock = leftRegion.front();
1026         leftYield = leftBlock.getTerminator();
1027       }
1028       // Right Region.
1029       Operation *rightYield = nullptr;
1030       if (!rightRegion.empty()) {
1031         Block &rightBlock = rightRegion.front();
1032         rightYield = rightBlock.getTerminator();
1033       }
1034       bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1035       bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1036       return combiSet(TensorExp::Kind::kBinary, child0, child1, binop,
1037                       includeLeft, TensorExp::Kind::kBinaryBranch, leftYield,
1038                       includeRight, TensorExp::Kind::kBinaryBranch, rightYield);
1039     }
1040   case TensorExp::Kind::kReduce:
1041     // A custom reduce operation.
1042     {
1043       const ExprId e0 = expr.children.e0;
1044       const ExprId e1 = expr.children.e1;
1045       Operation *const op = expr.op;
1046       return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i), op);
1047     }
1048   }
1049   llvm_unreachable("unexpected expression kind");
1050 }
1051 
1052 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1053   // Build the linalg semantics backward from yield.
1054   Operation *yield = op.getRegion().front().getTerminator();
1055   assert(isa<linalg::YieldOp>(yield));
1056   return buildTensorExp(op, yield->getOperand(0));
1057 }
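
// Schematically, for a generic op whose body computes
//   %0 = arith.mulf %a, %b : f64
//   linalg.yield %0 : f64
// the backward traversal builds the expression kMulF(kTensor(0), kTensor(1)),
// which buildLattices() later expands into per-loop iteration lattices.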
1058 
1059 /// Only returns false if we are certain this is a nonzero.
1060 bool Merger::maybeZero(ExprId e) const {
1061   const auto &expr = exp(e);
1062   if (expr.kind == TensorExp::Kind::kInvariant) {
1063     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1064       ArrayAttr arrayAttr = c.getValue();
1065       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1066              cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1067     }
1068     if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1069       return c.value() == 0;
1070     if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1071       return c.value().isZero();
1072   }
1073   return true;
1074 }
1075 
1076 Type Merger::inferType(ExprId e, Value src) const {
1077   // Obtain the destination type from the cast node.
1078   Type dtp = exp(e).val.getType();
1079   // Inspect source type. For vector types, apply the same
1080   // vectorization to the destination type.
1081   if (auto vtp = dyn_cast<VectorType>(src.getType()))
1082     return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
1083   return dtp;
1084 }
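
// For example, truncating a vector<8xf64> source with a cast whose scalar
// result type is f32 yields the vectorized destination type vector<8xf32>.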
1085 
1086 /// Ensures that the sparse compiler can generate code for the expression.
1087 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1088   // Arguments are always admissible.
1089   if (isa<BlockArgument>(v))
1090     return true;
1091   // Accept index anywhere.
1092   Operation *def = v.getDefiningOp();
1093   if (isa<linalg::IndexOp>(def))
1094     return true;
1095   // Operation defined outside branch.
1096   if (def->getBlock() != block)
1097     return def->getBlock() != op->getBlock(); // invariant?
1098   // Operation defined within branch. Anything is accepted,
1099   // as long as all subexpressions are admissible.
1100   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1101     if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1102       return false;
1103   return true;
1104 }
1105 
1106 /// Ensures that the sparse compiler can generate code for the branch.
1107 static bool isAdmissibleBranch(Operation *op, Region &region) {
1108   if (region.empty())
1109     return true;
1110   // Build the semi-ring branch semantics backward from yield.
1111   Operation *yield = region.front().getTerminator();
1112   assert(isa<YieldOp>(yield));
1113   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1114 }
1115 
1116 std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1117   if (auto arg = dyn_cast<BlockArgument>(v)) {
1118     const TensorId tid = makeTensorId(arg.getArgNumber());
1119     // Any argument of the generic op that is not marked as a scalar
1120     // argument is considered a tensor, indexed by the implicit loop
1121     // bounds. This includes rank-0 tensor arguments.
1122     if (arg.getOwner()->getParentOp() == op) {
1123       OpOperand &t = op->getOpOperand(tid);
1124       if (!op.isScalar(&t))
1125         return addTensorExp(tid);
1126       v = t.get(); // get scalar value
1127     }
1128     // Any other argument (marked as scalar argument for the generic op
1129     // or belonging to an enveloping op) is considered invariant.
1130     return addInvariantExp(v);
1131   }
1132   // Something defined outside is invariant.
1133   Operation *def = v.getDefiningOp();
1134   if (def->getBlock() != &op.getRegion().front())
1135     return addInvariantExp(v);
1136   // Construct index operations.
1137   if (def->getNumOperands() == 0) {
1138     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1139       return addLoopVarExp(makeLoopId(indexOp.getDim()));
1140   }
1141   // Construct unary operations if subexpression can be built.
1142   if (def->getNumOperands() == 1) {
1143     const auto x = buildTensorExp(op, def->getOperand(0));
1144     if (x.has_value()) {
1145       const ExprId e = *x;
1146       if (isa<math::AbsFOp>(def))
1147         return addExp(TensorExp::Kind::kAbsF, e);
1148       if (isa<complex::AbsOp>(def))
1149         return addExp(TensorExp::Kind::kAbsC, e);
1150       if (isa<math::AbsIOp>(def))
1151         return addExp(TensorExp::Kind::kAbsI, e);
1152       if (isa<math::CeilOp>(def))
1153         return addExp(TensorExp::Kind::kCeilF, e);
1154       if (isa<math::FloorOp>(def))
1155         return addExp(TensorExp::Kind::kFloorF, e);
1156       if (isa<math::SqrtOp>(def))
1157         return addExp(TensorExp::Kind::kSqrtF, e);
1158       if (isa<complex::SqrtOp>(def))
1159         return addExp(TensorExp::Kind::kSqrtC, e);
1160       if (isa<math::ExpM1Op>(def))
1161         return addExp(TensorExp::Kind::kExpm1F, e);
1162       if (isa<complex::Expm1Op>(def))
1163         return addExp(TensorExp::Kind::kExpm1C, e);
1164       if (isa<math::Log1pOp>(def))
1165         return addExp(TensorExp::Kind::kLog1pF, e);
1166       if (isa<complex::Log1pOp>(def))
1167         return addExp(TensorExp::Kind::kLog1pC, e);
1168       if (isa<math::SinOp>(def))
1169         return addExp(TensorExp::Kind::kSinF, e);
1170       if (isa<complex::SinOp>(def))
1171         return addExp(TensorExp::Kind::kSinC, e);
1172       if (isa<math::TanhOp>(def))
1173         return addExp(TensorExp::Kind::kTanhF, e);
1174       if (isa<complex::TanhOp>(def))
1175         return addExp(TensorExp::Kind::kTanhC, e);
1176       if (isa<arith::NegFOp>(def))
1177         return addExp(TensorExp::Kind::kNegF, e); // no negi in std
1178       if (isa<complex::NegOp>(def))
1179         return addExp(TensorExp::Kind::kNegC, e);
1180       if (isa<arith::TruncFOp>(def))
1181         return addExp(TensorExp::Kind::kTruncF, e, v);
1182       if (isa<arith::ExtFOp>(def))
1183         return addExp(TensorExp::Kind::kExtF, e, v);
1184       if (isa<arith::FPToSIOp>(def))
1185         return addExp(TensorExp::Kind::kCastFS, e, v);
1186       if (isa<arith::FPToUIOp>(def))
1187         return addExp(TensorExp::Kind::kCastFU, e, v);
1188       if (isa<arith::SIToFPOp>(def))
1189         return addExp(TensorExp::Kind::kCastSF, e, v);
1190       if (isa<arith::UIToFPOp>(def))
1191         return addExp(TensorExp::Kind::kCastUF, e, v);
1192       if (isa<arith::ExtSIOp>(def))
1193         return addExp(TensorExp::Kind::kCastS, e, v);
1194       if (isa<arith::ExtUIOp>(def))
1195         return addExp(TensorExp::Kind::kCastU, e, v);
1196       if (isa<arith::IndexCastOp>(def))
1197         return addExp(TensorExp::Kind::kCastIdx, e, v);
1198       if (isa<arith::TruncIOp>(def))
1199         return addExp(TensorExp::Kind::kTruncI, e, v);
1200       if (isa<complex::ImOp>(def))
1201         return addExp(TensorExp::Kind::kCIm, e);
1202       if (isa<complex::ReOp>(def))
1203         return addExp(TensorExp::Kind::kCRe, e);
1204       if (isa<arith::BitcastOp>(def))
1205         return addExp(TensorExp::Kind::kBitCast, e, v);
1206       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1207         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1208             isAdmissibleBranch(unop, unop.getAbsentRegion()))
1209           return addExp(TensorExp::Kind::kUnary, e, Value(), def);
1210       }
1211       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1212         if (isAdmissibleBranch(selop, selop.getRegion()))
1213           return addExp(TensorExp::Kind::kSelect, e, Value(), def);
1214       }
1215     }
1216   }
1217   // Construct binary operations if subexpressions can be built.
1218   // See buildLattices() for an explanation of rejecting certain
1219   // division and shift operations.
1220   if (def->getNumOperands() == 2) {
1221     const auto x = buildTensorExp(op, def->getOperand(0));
1222     const auto y = buildTensorExp(op, def->getOperand(1));
1223     if (x.has_value() && y.has_value()) {
1224       const ExprId e0 = *x;
1225       const ExprId e1 = *y;
1226       if (isa<arith::MulFOp>(def))
1227         return addExp(TensorExp::Kind::kMulF, e0, e1);
1228       if (isa<complex::MulOp>(def))
1229         return addExp(TensorExp::Kind::kMulC, e0, e1);
1230       if (isa<arith::MulIOp>(def))
1231         return addExp(TensorExp::Kind::kMulI, e0, e1);
1232       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1233         return addExp(TensorExp::Kind::kDivF, e0, e1);
1234       if (isa<complex::DivOp>(def) && !maybeZero(e1))
1235         return addExp(TensorExp::Kind::kDivC, e0, e1);
1236       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1237         return addExp(TensorExp::Kind::kDivS, e0, e1);
1238       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1239         return addExp(TensorExp::Kind::kDivU, e0, e1);
1240       if (isa<arith::AddFOp>(def))
1241         return addExp(TensorExp::Kind::kAddF, e0, e1);
1242       if (isa<complex::AddOp>(def))
1243         return addExp(TensorExp::Kind::kAddC, e0, e1);
1244       if (isa<arith::AddIOp>(def))
1245         return addExp(TensorExp::Kind::kAddI, e0, e1);
1246       if (isa<arith::SubFOp>(def))
1247         return addExp(TensorExp::Kind::kSubF, e0, e1);
1248       if (isa<complex::SubOp>(def))
1249         return addExp(TensorExp::Kind::kSubC, e0, e1);
1250       if (isa<arith::SubIOp>(def))
1251         return addExp(TensorExp::Kind::kSubI, e0, e1);
1252       if (isa<arith::AndIOp>(def))
1253         return addExp(TensorExp::Kind::kAndI, e0, e1);
1254       if (isa<arith::OrIOp>(def))
1255         return addExp(TensorExp::Kind::kOrI, e0, e1);
1256       if (isa<arith::XOrIOp>(def))
1257         return addExp(TensorExp::Kind::kXorI, e0, e1);
1258       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1259         return addExp(TensorExp::Kind::kShrS, e0, e1);
1260       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1261         return addExp(TensorExp::Kind::kShrU, e0, e1);
1262       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1263         return addExp(TensorExp::Kind::kShlI, e0, e1);
1264       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1265         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1266             (binop.getLeftIdentity() ||
1267              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1268             (binop.getRightIdentity() ||
1269              isAdmissibleBranch(binop, binop.getRightRegion())))
1270           return addExp(TensorExp::Kind::kBinary, e0, e1, def);
1271       }
1272     }
1273   }
1274   // Construct ternary operations if subexpressions can be built.
1275   if (def->getNumOperands() == 3) {
1276     const auto x = buildTensorExp(op, def->getOperand(0));
1277     const auto y = buildTensorExp(op, def->getOperand(1));
1278     const auto z = buildTensorExp(op, def->getOperand(2));
1279     if (x.has_value() && y.has_value() && z.has_value()) {
1280       const ExprId e0 = *x;
1281       const ExprId e1 = *y;
1282       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1283         if (isAdmissibleBranch(redop, redop.getRegion()))
1284           return addExp(TensorExp::Kind::kReduce, e0, e1, def);
1285       }
1286     }
1287   }
1288   // Cannot build.
1289   return std::nullopt;
1290 }
1291 
1292 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1293                            ValueRange vals) {
1294   // Make a clone of the given region.
1295   Region tmpRegion;
1296   IRMapping mapper;
1297   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1298   Block &clonedBlock = tmpRegion.front();
1299   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1300   // Merge cloned block and return yield value.
1301   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1302   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1303   Value val = clonedYield.getResult();
1304   rewriter.eraseOp(clonedYield);
1305   rewriter.eraseOp(placeholder);
1306   return val;
1307 }
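
// The block arguments of the cloned region are replaced by `vals` when the
// block is inlined before the placeholder, so the returned value is the
// region's yielded value expressed in terms of the caller's operands; the
// cloned yield and the placeholder are then erased.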
1308 
1309 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1310                                Operation *op, Value v0) {
1311   if (!v0)
1312     // Empty input value must be propagated.
1313     return Value();
1314   UnaryOp unop = cast<UnaryOp>(op);
1315   Region &presentRegion = unop.getPresentRegion();
1316   if (presentRegion.empty())
1317     // Uninitialized Value() will be interpreted as missing data in the
1318     // output.
1319     return Value();
1320   return insertYieldOp(rewriter, loc, presentRegion, {v0});
1321 }
1322 
1323 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1324                                 Operation *op, Value v0, Value v1) {
1325   if (!v0 || !v1)
1326     // Empty input values must be propagated.
1327     return Value();
1328   BinaryOp binop = cast<BinaryOp>(op);
1329   Region &overlapRegion = binop.getOverlapRegion();
1330   if (overlapRegion.empty())
1331     // Uninitialized Value() will be interpreted as missing data in the
1332     // output.
1333     return Value();
1334   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1335 }
1336 
1337 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1338                        Value v1) const {
1339   const auto &expr = exp(e);
1340   switch (expr.kind) {
1341   // Leaf.
1342   case TensorExp::Kind::kTensor:
1343   case TensorExp::Kind::kInvariant:
1344   case TensorExp::Kind::kLoopVar:
1345     llvm_unreachable("unexpected non-op");
1346   // Unary operations.
1347   case TensorExp::Kind::kAbsF:
1348     return rewriter.create<math::AbsFOp>(loc, v0);
1349   case TensorExp::Kind::kAbsC: {
1350     auto type = cast<ComplexType>(v0.getType());
1351     auto eltType = cast<FloatType>(type.getElementType());
1352     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1353   }
1354   case TensorExp::Kind::kAbsI:
1355     return rewriter.create<math::AbsIOp>(loc, v0);
1356   case TensorExp::Kind::kCeilF:
1357     return rewriter.create<math::CeilOp>(loc, v0);
1358   case TensorExp::Kind::kFloorF:
1359     return rewriter.create<math::FloorOp>(loc, v0);
1360   case TensorExp::Kind::kSqrtF:
1361     return rewriter.create<math::SqrtOp>(loc, v0);
1362   case TensorExp::Kind::kSqrtC:
1363     return rewriter.create<complex::SqrtOp>(loc, v0);
1364   case TensorExp::Kind::kExpm1F:
1365     return rewriter.create<math::ExpM1Op>(loc, v0);
1366   case TensorExp::Kind::kExpm1C:
1367     return rewriter.create<complex::Expm1Op>(loc, v0);
1368   case TensorExp::Kind::kLog1pF:
1369     return rewriter.create<math::Log1pOp>(loc, v0);
1370   case TensorExp::Kind::kLog1pC:
1371     return rewriter.create<complex::Log1pOp>(loc, v0);
1372   case TensorExp::Kind::kSinF:
1373     return rewriter.create<math::SinOp>(loc, v0);
1374   case TensorExp::Kind::kSinC:
1375     return rewriter.create<complex::SinOp>(loc, v0);
1376   case TensorExp::Kind::kTanhF:
1377     return rewriter.create<math::TanhOp>(loc, v0);
1378   case TensorExp::Kind::kTanhC:
1379     return rewriter.create<complex::TanhOp>(loc, v0);
1380   case TensorExp::Kind::kNegF:
1381     return rewriter.create<arith::NegFOp>(loc, v0);
1382   case TensorExp::Kind::kNegC:
1383     return rewriter.create<complex::NegOp>(loc, v0);
1384   case TensorExp::Kind::kNegI: // no negi in std
1385     return rewriter.create<arith::SubIOp>(
1386         loc,
1387         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1388                                            rewriter.getZeroAttr(v0.getType())),
1389         v0);
1390   case TensorExp::Kind::kTruncF:
1391     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1392   case TensorExp::Kind::kExtF:
1393     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1394   case TensorExp::Kind::kCastFS:
1395     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1396   case TensorExp::Kind::kCastFU:
1397     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1398   case TensorExp::Kind::kCastSF:
1399     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1400   case TensorExp::Kind::kCastUF:
1401     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1402   case TensorExp::Kind::kCastS:
1403     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1404   case TensorExp::Kind::kCastU:
1405     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1406   case TensorExp::Kind::kCastIdx:
1407     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1408   case TensorExp::Kind::kTruncI:
1409     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1410   case TensorExp::Kind::kCIm: {
1411     auto type = cast<ComplexType>(v0.getType());
1412     auto eltType = cast<FloatType>(type.getElementType());
1413     return rewriter.create<complex::ImOp>(loc, eltType, v0);
1414   }
1415   case TensorExp::Kind::kCRe: {
1416     auto type = cast<ComplexType>(v0.getType());
1417     auto eltType = cast<FloatType>(type.getElementType());
1418     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1419   }
1420   case TensorExp::Kind::kBitCast:
1421     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1422   // Binary operations.
1423   case TensorExp::Kind::kMulF:
1424     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1425   case TensorExp::Kind::kMulC:
1426     return rewriter.create<complex::MulOp>(loc, v0, v1);
1427   case TensorExp::Kind::kMulI:
1428     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1429   case TensorExp::Kind::kDivF:
1430     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1431   case TensorExp::Kind::kDivC:
1432     return rewriter.create<complex::DivOp>(loc, v0, v1);
1433   case TensorExp::Kind::kDivS:
1434     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1435   case TensorExp::Kind::kDivU:
1436     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1437   case TensorExp::Kind::kAddF:
1438     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1439   case TensorExp::Kind::kAddC:
1440     return rewriter.create<complex::AddOp>(loc, v0, v1);
1441   case TensorExp::Kind::kAddI:
1442     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1443   case TensorExp::Kind::kSubF:
1444     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1445   case TensorExp::Kind::kSubC:
1446     return rewriter.create<complex::SubOp>(loc, v0, v1);
1447   case TensorExp::Kind::kSubI:
1448     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1449   case TensorExp::Kind::kAndI:
1450     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1451   case TensorExp::Kind::kOrI:
1452     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1453   case TensorExp::Kind::kXorI:
1454     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1455   case TensorExp::Kind::kShrS:
1456     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1457   case TensorExp::Kind::kShrU:
1458     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1459   case TensorExp::Kind::kShlI:
1460     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1461   case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1462     return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1463                          {v0});
1464   case TensorExp::Kind::kUnary:
1465     return buildUnaryPresent(rewriter, loc, expr.op, v0);
1466   case TensorExp::Kind::kSelect:
1467     return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1468                          {v0});
1469   case TensorExp::Kind::kBinary:
1470     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1471   case TensorExp::Kind::kReduce: {
1472     ReduceOp redOp = cast<ReduceOp>(expr.op);
1473     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1474   }
1475   }
1476   llvm_unreachable("unexpected expression kind in build");
1477 }
1478 
1479 } // namespace sparse_tensor
1480 } // namespace mlir
1481