//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

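/// Returns the arity (number of child subexpressions) of the given
/// tensor expression kind.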
static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    attr = a;
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
               unsigned numFilterLoops, unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
      numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToDependencies(
          numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
                        numTensors, std::nullopt)),
      levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
                                           maxLvlRank, std::vector<LoopId>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
  return eNew;
}

ExprId Merger::addSynZeroExp() {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute attr = exp(e).attr;
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
  return pNew;
}

LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
  return sNew;
}

LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = conjSet(e, s0, s1, op);
  TensorExp::Kind kind = exp(e).kind;

  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}

LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build disjunctive set for it.
    return sNew;
  }

  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}

LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {
  const LatSetId sNew = conjSet(e, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op) {
  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
  }
  return sNew;
}

LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute a = exp(e).attr;
  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
  // Must be a binary operation.
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const ExprId zeroExp = addSynZeroExp();
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}

BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice; no other point is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not of an undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate over the bits in reverse
  // order so that the rightmost bit (which could possibly be a synthetic
  // tensor) is always kept.
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // Slice on dense level has `locate` property as well, and can be optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto dlt = getLvlType(b);
      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
          !isCompressedWithHiDLT(dlt)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
    return false;
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto dlt = getLvlType(b);
    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
        isCompressedWithHiDLT(dlt))
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

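/// Returns a mnemonic operator symbol for the given tensor expression kind;
/// only used by the debug dump methods below.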
static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    llvm::dbgs() << " ";
    dumpExp(expr.children.e1);
    llvm::dbgs() << ")";
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto dlt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //   x < y |  !y   |   y   |
    //  -------+-------+-------+
    //     !x  |   0   | 0 < y |
    //      x  | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
             cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

Type Merger::inferType(ExprId e, Value src) const {
  // Obtain the destination type from the cast node.
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = dyn_cast<VectorType>(src.getType()))
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

/// Ensures that sparse compiler can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that sparse compiler can generate code for branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      if (!op.isScalar(&t))
        return addTensorExp(tid);
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addInvariantExp(v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return addInvariantExp(v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addLoopVarExp(makeLoopId(indexOp.getDim()));
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return addExp(TensorExp::Kind::kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(TensorExp::Kind::kAbsC, e);
      if (isa<math::AbsIOp>(def))
        return addExp(TensorExp::Kind::kAbsI, e);
      if (isa<math::CeilOp>(def))
        return addExp(TensorExp::Kind::kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(TensorExp::Kind::kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(TensorExp::Kind::kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(TensorExp::Kind::kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(TensorExp::Kind::kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(TensorExp::Kind::kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(TensorExp::Kind::kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(TensorExp::Kind::kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(TensorExp::Kind::kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(TensorExp::Kind::kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(TensorExp::Kind::kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(TensorExp::Kind::kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(TensorExp::Kind::kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(TensorExp::Kind::kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(TensorExp::Kind::kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(TensorExp::Kind::kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(TensorExp::Kind::kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(TensorExp::Kind::kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(TensorExp::Kind::kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(TensorExp::Kind::kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(TensorExp::Kind::kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(TensorExp::Kind::kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(TensorExp::Kind::kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(TensorExp::Kind::kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(TensorExp::Kind::kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(TensorExp::Kind::kBitCast, e, v);
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    const auto y = buildTensorExp(op, def->getOperand(1));
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return addExp(TensorExp::Kind::kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(TensorExp::Kind::kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(TensorExp::Kind::kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(TensorExp::Kind::kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(TensorExp::Kind::kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(TensorExp::Kind::kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(TensorExp::Kind::kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(TensorExp::Kind::kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(TensorExp::Kind::kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(TensorExp::Kind::kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(TensorExp::Kind::kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(TensorExp::Kind::kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShlI, e0, e1);
      if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
        if (ci.getPredicate() == arith::CmpIPredicate::eq ||
            ci.getPredicate() == arith::CmpIPredicate::sle ||
            ci.getPredicate() == arith::CmpIPredicate::sge ||
            ci.getPredicate() == arith::CmpIPredicate::ule ||
            ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify comparisons whose predicate holds for equal
          // operands, because 0 <= 0 yields true and thus densifies the result.
          return std::nullopt;
        }

        return addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                      ci.getPredicateAttr());
      }
      if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify comparisons whose predicate holds for equal
          // operands, because 0 <= 0 yields true and thus densifies the result.
          return std::nullopt;
        }
        return addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                      cf.getPredicateAttr());
      }
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return addExp(TensorExp::Kind::kBinary, e0, e1, def);
      }
    }
  }
  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    const auto y = buildTensorExp(op, def->getOperand(1));
    const auto z = buildTensorExp(op, def->getOperand(2));
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return addExp(TensorExp::Kind::kReduce, e0, e1, def);
      }
    }
  }
  // Cannot build.
  return std::nullopt;
}

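/// Clones the given semi-ring region, inlines it at the rewriter's current
/// insertion point (with `vals` substituted for the block arguments), and
/// returns the value produced by the cloned yield.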
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

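/// Builds the present-region value of a sparse_tensor::UnaryOp for operand
/// `v0`, or returns an empty Value when the operand or the region is missing.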
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

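/// Builds the overlap-region value of a sparse_tensor::BinaryOp for operands
/// `v0` and `v1`, or returns an empty Value when either operand or the region
/// is missing.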
1451 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1452                                 Operation *op, Value v0, Value v1) {
1453   if (!v0 || !v1)
1454     // Empty input values must be propagated.
1455     return Value();
1456   BinaryOp binop = cast<BinaryOp>(op);
1457   Region &overlapRegion = binop.getOverlapRegion();
1458   if (overlapRegion.empty())
1459     // Uninitialized Value() will be interpreted as missing data in the
1460     // output.
1461     return Value();
1462   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1463 }
1464 
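     // Rebuilds the SSA form of tensor expression `e`, with `v0` and `v1`
     // supplying the values of its (at most two) already generated operands.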
1465 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1466                        Value v1) const {
1467   const auto &expr = exp(e);
1468   switch (expr.kind) {
1469   // Leaf.
1470   case TensorExp::Kind::kTensor:
1471   case TensorExp::Kind::kInvariant:
1472   case TensorExp::Kind::kLoopVar:
1473   case TensorExp::Kind::kSynZero:
1474     llvm_unreachable("unexpected non-op");
1475   // Unary operations.
1476   case TensorExp::Kind::kAbsF:
1477     return rewriter.create<math::AbsFOp>(loc, v0);
1478   case TensorExp::Kind::kAbsC: {
1479     auto type = cast<ComplexType>(v0.getType());
1480     auto eltType = cast<FloatType>(type.getElementType());
1481     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1482   }
1483   case TensorExp::Kind::kAbsI:
1484     return rewriter.create<math::AbsIOp>(loc, v0);
1485   case TensorExp::Kind::kCeilF:
1486     return rewriter.create<math::CeilOp>(loc, v0);
1487   case TensorExp::Kind::kFloorF:
1488     return rewriter.create<math::FloorOp>(loc, v0);
1489   case TensorExp::Kind::kSqrtF:
1490     return rewriter.create<math::SqrtOp>(loc, v0);
1491   case TensorExp::Kind::kSqrtC:
1492     return rewriter.create<complex::SqrtOp>(loc, v0);
1493   case TensorExp::Kind::kExpm1F:
1494     return rewriter.create<math::ExpM1Op>(loc, v0);
1495   case TensorExp::Kind::kExpm1C:
1496     return rewriter.create<complex::Expm1Op>(loc, v0);
1497   case TensorExp::Kind::kLog1pF:
1498     return rewriter.create<math::Log1pOp>(loc, v0);
1499   case TensorExp::Kind::kLog1pC:
1500     return rewriter.create<complex::Log1pOp>(loc, v0);
1501   case TensorExp::Kind::kSinF:
1502     return rewriter.create<math::SinOp>(loc, v0);
1503   case TensorExp::Kind::kSinC:
1504     return rewriter.create<complex::SinOp>(loc, v0);
1505   case TensorExp::Kind::kTanhF:
1506     return rewriter.create<math::TanhOp>(loc, v0);
1507   case TensorExp::Kind::kTanhC:
1508     return rewriter.create<complex::TanhOp>(loc, v0);
1509   case TensorExp::Kind::kNegF:
1510     return rewriter.create<arith::NegFOp>(loc, v0);
1511   case TensorExp::Kind::kNegC:
1512     return rewriter.create<complex::NegOp>(loc, v0);
1513   case TensorExp::Kind::kNegI: // no integer negation in arith; emit 0 - v0
1514     return rewriter.create<arith::SubIOp>(
1515         loc,
1516         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1517                                            rewriter.getZeroAttr(v0.getType())),
1518         v0);
1519   case TensorExp::Kind::kTruncF:
1520     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1521   case TensorExp::Kind::kExtF:
1522     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1523   case TensorExp::Kind::kCastFS:
1524     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1525   case TensorExp::Kind::kCastFU:
1526     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1527   case TensorExp::Kind::kCastSF:
1528     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1529   case TensorExp::Kind::kCastUF:
1530     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1531   case TensorExp::Kind::kCastS:
1532     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1533   case TensorExp::Kind::kCastU:
1534     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1535   case TensorExp::Kind::kCastIdx:
1536     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1537   case TensorExp::Kind::kTruncI:
1538     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1539   case TensorExp::Kind::kCIm: {
1540     auto type = cast<ComplexType>(v0.getType());
1541     auto eltType = cast<FloatType>(type.getElementType());
1542     return rewriter.create<complex::ImOp>(loc, eltType, v0);
1543   }
1544   case TensorExp::Kind::kCRe: {
1545     auto type = cast<ComplexType>(v0.getType());
1546     auto eltType = cast<FloatType>(type.getElementType());
1547     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1548   }
1549   case TensorExp::Kind::kBitCast:
1550     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1551   // Binary operations.
1552   case TensorExp::Kind::kMulF:
1553     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1554   case TensorExp::Kind::kMulC:
1555     return rewriter.create<complex::MulOp>(loc, v0, v1);
1556   case TensorExp::Kind::kMulI:
1557     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1558   case TensorExp::Kind::kDivF:
1559     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1560   case TensorExp::Kind::kDivC:
1561     return rewriter.create<complex::DivOp>(loc, v0, v1);
1562   case TensorExp::Kind::kDivS:
1563     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1564   case TensorExp::Kind::kDivU:
1565     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1566   case TensorExp::Kind::kAddF:
1567     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1568   case TensorExp::Kind::kAddC:
1569     return rewriter.create<complex::AddOp>(loc, v0, v1);
1570   case TensorExp::Kind::kAddI:
1571     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1572   case TensorExp::Kind::kSubF:
1573     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1574   case TensorExp::Kind::kSubC:
1575     return rewriter.create<complex::SubOp>(loc, v0, v1);
1576   case TensorExp::Kind::kSubI:
1577     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1578   case TensorExp::Kind::kAndI:
1579     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1580   case TensorExp::Kind::kOrI:
1581     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1582   case TensorExp::Kind::kXorI:
1583     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1584   case TensorExp::Kind::kShrS:
1585     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1586   case TensorExp::Kind::kShrU:
1587     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1588   case TensorExp::Kind::kShlI:
1589     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1590   case TensorExp::Kind::kCmpI: {
1591     auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1592     return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1593   }
1594   case TensorExp::Kind::kCmpF: {
1595     auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1596     return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1597   }
1598   case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1599     return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1600                          {v0});
1601   case TensorExp::Kind::kUnary:
1602     return buildUnaryPresent(rewriter, loc, expr.op, v0);
1603   case TensorExp::Kind::kSelect:
1604     return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1605                          {v0});
1606   case TensorExp::Kind::kBinary:
1607     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1608   case TensorExp::Kind::kReduce: {
1609     ReduceOp redOp = cast<ReduceOp>(expr.op);
1610     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1611   }
1612   }
1613   llvm_unreachable("unexpected expression kind in build");
1614 }
1615 
1616 } // namespace sparse_tensor
1617 } // namespace mlir
1618