//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    : kind(k), val(v), op(o), attr(a) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kDenseOp:
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
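
// Example: with numInputOutputTensors = 3 (two inputs, one output) and
// numLoops = 2, the initializer list above yields outTensor = 2,
// syntheticTensor = 3, and numTensors = 4, so every per-tensor table is
// sized to cover the synthetic tensor as well.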

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
  return eNew;
}

ExprId Merger::addSynZeroExp() {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}
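
// Illustrative note: each (tensor, loop) pair maps to one bit position in a
// lattice point's bit vector of size numLoops * numTensors. Assuming the
// packing used by makeTensorLoopId() in Merger.h (b = i * numTensors + t),
// tensor 1 at loop 2 with numTensors = 4 would occupy bit 2 * 4 + 1 = 9.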

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute attr = exp(e).attr;
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
  return pNew;
}

LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
  return sNew;
}
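
// Example: conjoining s0 = { p00, p01 } with s1 = { p10 } yields the cross
// product { p00&p10, p01&p10 }, where each new point (see conjLat above)
// takes the union of the operand bits and combines the operand expressions
// under the operator of `e`. The point names are purely illustrative.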

LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = conjSet(e, s0, s1, op);
  TensorExp::Kind kind = exp(e).kind;
  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}
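
// Example: for x(i) = a(i) - b(i), the disjunction built above contains, in
// order, the conjunction a(i)-b(i) (both operands present), then a(i) alone,
// and finally b(i) mapped to the unary -b(i), matching the x-y truth table
// in buildLattices() below.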

LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build disjunctive set for it.
    return sNew;
  }

  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}

LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {
  Attribute a = exp(e).attr;
  const LatSetId sNew = conjSet(e, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft, a);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright, a);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op, Attribute a) {
  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
         TensorExp::Kind::kDenseOp == kind);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
  }
  return sNew;
}

LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute a = exp(e).attr;
  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
  // Must be a binary operation.
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const ExprId zeroExp = addSynZeroExp();
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}
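
// Example: if an earlier kept point iterates over sparse tensor a together
// with some dense tensor d, a later point whose condition is just a differs
// from it only in the dense bit, so onlyDenseDiff() holds and the later
// point is dropped: the dense level is visited anyway.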

BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice: no other point is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // does not have an undefined level type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. Note that we iterate over the bits in
  // reverse to always keep the rightmost bit (which could possibly be a
  // synthetic tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!lt.hasSparseSemantic()) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}
502 
latGT(LatPointId i,LatPointId j) const503 bool Merger::latGT(LatPointId i, LatPointId j) const {
504   const BitVector &bitsi = lat(i).bits;
505   const BitVector &bitsj = lat(j).bits;
506   assert(bitsi.size() == bitsj.size());
507   if (bitsi.count() > bitsj.count()) {
508     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509       if (bitsj[b] && !bitsi[b])
510         return false;
511     return true;
512   }
513   return false;
514 }
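
// Example: a point with bits {a, b} is strictly greater than a point with
// bits {a}, since it has more bits set and contains all of them; points
// with incomparable bit sets, e.g. {a} versus {b}, are not ordered by
// latGT().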

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  case TensorExp::Kind::kDenseOp: {
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
    return lhsNeg;
  }
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
    return false;
  case TensorExp::Kind::kDenseOp:
    // Since the Merger guarantees that all operands of a kDenseOp are dense,
    // the operation must be single-condition.
    return true;
  }
  llvm_unreachable("unexpected kind");
}
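
// Example: x(i) = a(i) * 2.0 is a single condition for tensor a, since the
// result is needed exactly where a is present; x(i) = a(i) + b(i) is not,
// because the sum must also be computed where only b is present.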

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto lt = getLvlType(b);
    if (lt.hasSparseSemantic())
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kRelu:
    return "relu";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
    } else {
      // Only a unary kDenseOp lacks a second child.
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    llvm::dbgs() << ")";
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto lt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      Attribute a = expr.attr;
      return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getSingleResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //   x < y |  !y   |   y   |
    //  -------+-------+-------+
    //     !x  |   0   | 0 < y |
    //      x  | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or disjunctive
    // set here: since all operands of kDenseOp must be dense, a disjunctive
    // set will eventually be optimized into a conjunctive set anyway.
    if (expr.children.e1 == detail::kInvalidId) {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  llvm_unreachable("unexpected expression kind");
}
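
// Worked example: for x(i) = a(i) + b(i), the leaf cases above produce the
// singleton lattice sets { a } and { b }, which the kAddF case combines
// with disjSet() into { a&b, a, b }: compute a+b where both operands are
// present, and copy the remaining operand where only one is present.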

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0)).first;
}

/// Only returns true if we are certain this is a zero.
static bool isCertainZero(Value val) {
  if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
    ArrayAttr arrayAttr = c.getValue();
    return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
           cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
  }
  if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
    return c.value() == 0;
  if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
    return c.value().isZero();
  return false;
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    // Note that this differs from isCertainZero() in a subtle way:
    // it always returns true for non-constants.
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
             cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}
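
// Example: maybeZero() yields false for the invariant constant 2.0 (so a
// division by it can be sparsified), but yields true for the constant 0.0
// and for any value that is not a compile-time constant.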
1232 
inferType(ExprId e,Value src) const1233 Type Merger::inferType(ExprId e, Value src) const {
1234   // Obtain the destination type from the cast node.
1235   Type dtp = exp(e).val.getType();
1236   // Inspect source type. For vector types, apply the same
1237   // vectorization to the destination type.
1238   if (auto vtp = dyn_cast<VectorType>(src.getType()))
1239     return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1240   return dtp;
1241 }
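
// Example: for a cast whose destination type is f64 and whose source value
// has type vector<16xf32>, the inferred type is vector<16xf64>, preserving
// the vectorization (including scalable dimensions) of the source.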

/// Ensures that the sparsifier can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that the sparsifier can generate code for branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

/// Recognizes a direct GT comparison.
static bool isGreater(TensorExp::Kind kind, Attribute attr) {
  if (kind == TensorExp::Kind::kCmpI) {
    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
    return pred == arith::CmpIPredicate::ugt ||
           pred == arith::CmpIPredicate::sgt;
  }
  if (kind == TensorExp::Kind::kCmpF) {
    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
    return pred == arith::CmpFPredicate::UGT ||
           pred == arith::CmpFPredicate::OGT;
  }
  return false;
}

std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  // Recursion leaves.
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
      if (!op.isScalar(&t))
        return {addTensorExp(tid), hasSpDep};
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return {addInvariantExp(v), /*hasSpDep=*/false};
  }

  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return {addInvariantExp(v), /*hasSpDep=*/false};
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
  }

  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
      if (isa<complex::AbsOp>(def))
        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
      if (isa<math::AbsIOp>(def))
        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
      if (isa<math::CeilOp>(def))
        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
      if (isa<math::FloorOp>(def))
        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
      if (isa<math::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
      if (isa<complex::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
      if (isa<math::ExpM1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
      if (isa<complex::Expm1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
      if (isa<math::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
      if (isa<complex::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
      if (isa<math::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
      if (isa<complex::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
      if (isa<math::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
      if (isa<complex::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
      if (isa<arith::NegFOp>(def))
        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
      if (isa<complex::NegOp>(def))
        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
      if (isa<arith::TruncFOp>(def))
        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
      if (isa<arith::ExtFOp>(def))
        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
      if (isa<arith::FPToSIOp>(def))
        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
      if (isa<arith::FPToUIOp>(def))
        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
      if (isa<arith::SIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
      if (isa<arith::UIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
      if (isa<arith::ExtSIOp>(def))
        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
      if (isa<arith::ExtUIOp>(def))
        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
      if (isa<arith::IndexCastOp>(def))
        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
      if (isa<arith::TruncIOp>(def))
        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
      if (isa<complex::ImOp>(def))
        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
      if (isa<complex::ReOp>(def))
        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
      if (isa<arith::BitcastOp>(def))
        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
      }
    }
  }

  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
    const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
    // For a conjunctive operation, it yields a "sparse" result if any operand
    // is sparse. For a disjunctive operation, it yields a "sparse" result if
    // all operands are sparse.
    bool conjSpVals = xSpVals || ySpVals;
    bool disjSpVals = xSpVals && ySpVals;
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
      if (isa<complex::MulOp>(def))
        return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
      if (isa<arith::MulIOp>(def))
        return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
      if (isa<arith::AddFOp>(def))
        return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
      if (isa<complex::AddOp>(def))
        return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
      if (isa<arith::AddIOp>(def))
        return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
      if (isa<arith::SubFOp>(def))
        return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
      if (isa<complex::SubOp>(def))
        return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
      if (isa<arith::SubIOp>(def))
        return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
      if (isa<arith::AndIOp>(def))
        return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
      if (isa<arith::OrIOp>(def))
        return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
      if (isa<arith::XOrIOp>(def))
        return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
1447       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1448         if (ci.getPredicate() == arith::CmpIPredicate::eq ||
1449             ci.getPredicate() == arith::CmpIPredicate::sle ||
1450             ci.getPredicate() == arith::CmpIPredicate::sge ||
1451             ci.getPredicate() == arith::CmpIPredicate::ule ||
1452             ci.getPredicate() == arith::CmpIPredicate::uge) {
1453           // We cannot sparsify a comparison that holds on equal operands,
1454           // since, e.g., 0 <= 0 yields true and thus densifies the result.
1455           return {std::nullopt, false};
1456         }
1457 
1458         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1459                         ci.getPredicateAttr());
1460         return {e, conjSpVals};
1461       }
1462       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1463         if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
1464             cf.getPredicate() == arith::CmpFPredicate::OGE ||
1465             cf.getPredicate() == arith::CmpFPredicate::OLE ||
1466             cf.getPredicate() == arith::CmpFPredicate::ONE ||
1467             cf.getPredicate() == arith::CmpFPredicate::UEQ ||
1468             cf.getPredicate() == arith::CmpFPredicate::UGE ||
1469             cf.getPredicate() == arith::CmpFPredicate::ULE ||
1470             cf.getPredicate() == arith::CmpFPredicate::ORD ||
1471             cf.getPredicate() == arith::CmpFPredicate::UNO) {
1472           // We cannot sparsify a comparison that holds on equal operands,
1473           // since, e.g., 0 <= 0 yields true and thus densifies the result.
1474           return {std::nullopt, false};
1475         }
1476         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1477                         cf.getPredicateAttr());
1478         return {e, conjSpVals};
1479       }
1480       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1481         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1482             (binop.getLeftIdentity() ||
1483              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1484             (binop.getRightIdentity() ||
1485              isAdmissibleBranch(binop, binop.getRightRegion())))
1486           return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
1487       }
1488     }
1489   }
1490 
1491   // Construct ternary operations if subexpressions can be built.
1492   if (def->getNumOperands() == 3) {
1493     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495     const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496     bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497     if (x.has_value() && y.has_value() && z.has_value()) {
1498       const ExprId e0 = *x;
1499       const ExprId e1 = *y;
1500       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1501         if (isAdmissibleBranch(redop, redop.getRegion()))
1502           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1503       }
1504       if (auto selop = dyn_cast<arith::SelectOp>(def)) {
1505         // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
1506         // operation inside a very specific ternary select operation.
1507         // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
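        // (In sketch form, the pattern matched here is
        //    %c = arith.cmpf ugt, %t, %zero
        //    %r = arith.select %c, %t, %zero
        //  or its integer analogue, i.e., ReLu(t) = select(t > 0, t, 0).)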
1508         const auto &cnd = exp(*x);
1509         if (isGreater(cnd.kind, cnd.attr) &&
1510             exp(*y).kind == TensorExp::Kind::kTensor &&
1511             exp(*z).kind == TensorExp::Kind::kInvariant &&
1512             isCertainZero(exp(*z).val)) {
1513           const auto &a = exp(cnd.children.e0);
1514           const auto &b = exp(cnd.children.e1);
1515           if (a.kind == TensorExp::Kind::kTensor &&
1516               a.tensor == exp(*y).tensor &&
1517               b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
1518             return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
1519                            nullptr, cnd.attr),
1520                     yDepSp};
1521           }
1522         }
1523       }
1524     }
1525   }
1526 
1527   // If we reach here, we are dealing with an operation that is not currently
1528   // sparsifiable. We can still generate code for it if all its operands only
1529   // have dense dependencies (i.e., all the values are loaded from dense
1530   // tensors).
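  // (For instance, an op not recognized above whose inputs are all loaded
  // from dense tensors can simply be cloned into the loop nest later as a
  // kDenseOp; see the kDenseOp case in buildExp().)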
1531   if (def->getNumResults() != 1) // only handle single-result operations.
1532     return {std::nullopt, false};
1533   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1534   // Build all the subexpressions.
1535   for (Value operand : def->getOperands())
1536     subExp.push_back(buildTensorExp(op, operand));
1537 
1538   if (llvm::all_of(subExp,
1539                    [](auto e) { return e.first.has_value() && !e.second; })) {
1540     // All the subexpressions can be built and have *no* sparse dependencies.
1541     if (subExp.size() == 2) {
1542       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1543                       *subExp[1].first, def);
1544       return {e, false};
1545     }
1546     if (subExp.size() == 1) {
1547       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1548                       detail::kInvalidId, def);
1549       return {e, false};
1550     }
1551   }
1552 
1553   // Cannot build.
1554   return {std::nullopt, false};
1555 }
1556 
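// Clones the given single-block region at the current insertion point,
// replaces its block arguments with `vals`, and returns the value produced
// by its sparse_tensor.yield terminator.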
1557 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1558                            ValueRange vals) {
1559   // Make a clone of the given region.
1560   Region tmpRegion;
1561   IRMapping mapper;
1562   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1563   Block &clonedBlock = tmpRegion.front();
1564   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1565   // Merge cloned block and return yield value.
1566   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1567   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1568   Value val = clonedYield.getSingleResult();
1569   rewriter.eraseOp(clonedYield);
1570   rewriter.eraseOp(placeholder);
1571   return val;
1572 }
1573 
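// Lowers a sparse_tensor.unary operation by inlining its "present" region on
// the given operand. A missing operand or an empty region yields an empty
// Value(), which is interpreted as missing data in the output.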
1574 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1575                                Operation *op, Value v0) {
1576   if (!v0)
1577     // Empty input value must be propagated.
1578     return Value();
1579   UnaryOp unop = cast<UnaryOp>(op);
1580   Region &presentRegion = unop.getPresentRegion();
1581   if (presentRegion.empty())
1582     // Uninitialized Value() will be interpreted as missing data in the
1583     // output.
1584     return Value();
1585   return insertYieldOp(rewriter, loc, presentRegion, {v0});
1586 }
1587 
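// Lowers a sparse_tensor.binary operation by inlining its "overlap" region on
// the two operands; as above, missing inputs or an empty region propagate the
// empty Value().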
1588 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1589                                 Operation *op, Value v0, Value v1) {
1590   if (!v0 || !v1)
1591     // Empty input values must be propagated.
1592     return Value();
1593   BinaryOp binop = cast<BinaryOp>(op);
1594   Region &overlapRegion = binop.getOverlapRegion();
1595   if (overlapRegion.empty())
1596     // Uninitialized Value() will be interpreted as missing data in the
1597     // output.
1598     return Value();
1599   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1600 }
1601 
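// Materializes ReLu(v0) as a compare-with-zero followed by a select, using
// the comparison predicate that was captured in `attr` while the expression
// was built.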
1602 static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
1603                        Attribute attr) {
1604   Type tp = v0.getType();
1605   auto zero =
1606       rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
1607   Value cmp;
1608   if (isa<FloatType>(tp)) {
1609     auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
1610     cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
1611   } else {
1612     auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
1613     cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
1614   }
1615   return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
1616 }
1617 
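// Emits the actual IR for a single tensor expression, given previously
// generated values (if any) for its operands.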
1618 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1619                        Value v1) const {
1620   const auto &expr = exp(e);
1621   switch (expr.kind) {
1622   // Leaf.
1623   case TensorExp::Kind::kTensor:
1624   case TensorExp::Kind::kInvariant:
1625   case TensorExp::Kind::kLoopVar:
1626   case TensorExp::Kind::kSynZero:
1627     llvm_unreachable("unexpected non-op");
1628   // Unary operations.
1629   case TensorExp::Kind::kAbsF:
1630     return rewriter.create<math::AbsFOp>(loc, v0);
1631   case TensorExp::Kind::kAbsC: {
1632     auto type = cast<ComplexType>(v0.getType());
1633     auto eltType = cast<FloatType>(type.getElementType());
1634     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1635   }
1636   case TensorExp::Kind::kAbsI:
1637     return rewriter.create<math::AbsIOp>(loc, v0);
1638   case TensorExp::Kind::kCeilF:
1639     return rewriter.create<math::CeilOp>(loc, v0);
1640   case TensorExp::Kind::kFloorF:
1641     return rewriter.create<math::FloorOp>(loc, v0);
1642   case TensorExp::Kind::kSqrtF:
1643     return rewriter.create<math::SqrtOp>(loc, v0);
1644   case TensorExp::Kind::kSqrtC:
1645     return rewriter.create<complex::SqrtOp>(loc, v0);
1646   case TensorExp::Kind::kExpm1F:
1647     return rewriter.create<math::ExpM1Op>(loc, v0);
1648   case TensorExp::Kind::kExpm1C:
1649     return rewriter.create<complex::Expm1Op>(loc, v0);
1650   case TensorExp::Kind::kLog1pF:
1651     return rewriter.create<math::Log1pOp>(loc, v0);
1652   case TensorExp::Kind::kLog1pC:
1653     return rewriter.create<complex::Log1pOp>(loc, v0);
1654   case TensorExp::Kind::kRelu:
1655     return buildRelu(rewriter, loc, v0, expr.attr);
1656   case TensorExp::Kind::kSinF:
1657     return rewriter.create<math::SinOp>(loc, v0);
1658   case TensorExp::Kind::kSinC:
1659     return rewriter.create<complex::SinOp>(loc, v0);
1660   case TensorExp::Kind::kTanhF:
1661     return rewriter.create<math::TanhOp>(loc, v0);
1662   case TensorExp::Kind::kTanhC:
1663     return rewriter.create<complex::TanhOp>(loc, v0);
1664   case TensorExp::Kind::kNegF:
1665     return rewriter.create<arith::NegFOp>(loc, v0);
1666   case TensorExp::Kind::kNegC:
1667     return rewriter.create<complex::NegOp>(loc, v0);
1668   case TensorExp::Kind::kNegI: // no negi in std
1669     return rewriter.create<arith::SubIOp>(
1670         loc,
1671         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1672                                            rewriter.getZeroAttr(v0.getType())),
1673         v0);
1674   case TensorExp::Kind::kTruncF:
1675     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1676   case TensorExp::Kind::kExtF:
1677     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1678   case TensorExp::Kind::kCastFS:
1679     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1680   case TensorExp::Kind::kCastFU:
1681     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1682   case TensorExp::Kind::kCastSF:
1683     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1684   case TensorExp::Kind::kCastUF:
1685     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1686   case TensorExp::Kind::kCastS:
1687     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1688   case TensorExp::Kind::kCastU:
1689     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1690   case TensorExp::Kind::kCastIdx:
1691     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1692   case TensorExp::Kind::kTruncI:
1693     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1694   case TensorExp::Kind::kCIm: {
1695     auto type = cast<ComplexType>(v0.getType());
1696     auto eltType = cast<FloatType>(type.getElementType());
1697     return rewriter.create<complex::ImOp>(loc, eltType, v0);
1698   }
1699   case TensorExp::Kind::kCRe: {
1700     auto type = cast<ComplexType>(v0.getType());
1701     auto eltType = cast<FloatType>(type.getElementType());
1702     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1703   }
1704   case TensorExp::Kind::kBitCast:
1705     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1706   // Binary operations.
1707   case TensorExp::Kind::kMulF:
1708     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1709   case TensorExp::Kind::kMulC:
1710     return rewriter.create<complex::MulOp>(loc, v0, v1);
1711   case TensorExp::Kind::kMulI:
1712     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1713   case TensorExp::Kind::kDivF:
1714     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1715   case TensorExp::Kind::kDivC:
1716     return rewriter.create<complex::DivOp>(loc, v0, v1);
1717   case TensorExp::Kind::kDivS:
1718     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1719   case TensorExp::Kind::kDivU:
1720     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1721   case TensorExp::Kind::kAddF:
1722     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1723   case TensorExp::Kind::kAddC:
1724     return rewriter.create<complex::AddOp>(loc, v0, v1);
1725   case TensorExp::Kind::kAddI:
1726     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1727   case TensorExp::Kind::kSubF:
1728     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1729   case TensorExp::Kind::kSubC:
1730     return rewriter.create<complex::SubOp>(loc, v0, v1);
1731   case TensorExp::Kind::kSubI:
1732     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1733   case TensorExp::Kind::kAndI:
1734     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1735   case TensorExp::Kind::kOrI:
1736     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1737   case TensorExp::Kind::kXorI:
1738     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1739   case TensorExp::Kind::kShrS:
1740     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1741   case TensorExp::Kind::kShrU:
1742     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1743   case TensorExp::Kind::kShlI:
1744     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1745   case TensorExp::Kind::kCmpI: {
1746     auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1747     return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1748   }
1749   case TensorExp::Kind::kCmpF: {
1750     auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1751     return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1752   }
1753   case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1754     return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1755                          {v0});
1756   case TensorExp::Kind::kUnary:
1757     return buildUnaryPresent(rewriter, loc, expr.op, v0);
1758   case TensorExp::Kind::kSelect:
1759     return insertYieldOp(rewriter, loc,
1760                          cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
1761                          {v0});
1762   case TensorExp::Kind::kBinary:
1763     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1764   case TensorExp::Kind::kReduce: {
1765     ReduceOp redOp = cast<ReduceOp>(expr.op);
1766     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1767   }
1768   case TensorExp::Kind::kDenseOp: {
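    // Replay the original dense operation, with its operands remapped to the
    // newly generated operand values.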
1769     Operation *actualOp = expr.op;
1770     IRMapping mapping;
1771     mapping.map(actualOp->getOperand(0), v0);
1772     if (actualOp->getNumOperands() == 2)
1773       mapping.map(actualOp->getOperand(1), v1);
1774     return rewriter.clone(*actualOp, mapping)->getResult(0);
1775   }
1776   }
1777   llvm_unreachable("unexpected expression kind in build");
1778 }
1779 
1780 } // namespace sparse_tensor
1781 } // namespace mlir
1782