xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (revision 70e227a404e51f9248c7ad5d79953805b2afacb4)
1744146f6SGus Smith //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2744146f6SGus Smith //
3744146f6SGus Smith // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4744146f6SGus Smith // See https://llvm.org/LICENSE.txt for license information.
5744146f6SGus Smith // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6744146f6SGus Smith //
7744146f6SGus Smith //===----------------------------------------------------------------------===//
8744146f6SGus Smith 
9744146f6SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
11736c1b66SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
12eda6f907SRiver Riddle #include "mlir/Dialect/Math/IR/Math.h"
1390c2af57SMehdi Amini #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14744146f6SGus Smith 
15557b101cSAart Bik #include "mlir/IR/Operation.h"
16557b101cSAart Bik #include "llvm/Support/Debug.h"
17a1fe1f5fSKazu Hirata #include <optional>
18557b101cSAart Bik 
19744146f6SGus Smith namespace mlir {
20744146f6SGus Smith namespace sparse_tensor {
21744146f6SGus Smith 
/// Classifies how many operands (child subexpressions) a `TensorExp::Kind`
/// carries: leaves take none, casts/negations/unary ops take one, and
/// arithmetic/comparison/custom binary ops take two.
enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};
27faa75f94SPeiming Liu 
/// Maps a tensor expression kind to its operand arity. Note that
/// `kDenseOp` is classified as binary because it can have *at most*
/// two operands; callers must still check whether its second child
/// is actually present.
static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}
101faa75f94SPeiming Liu 
102e2d3db42SAart Bik //===----------------------------------------------------------------------===//
103b8a021dbSAart Bik // Constructors.
104e2d3db42SAart Bik //===----------------------------------------------------------------------===//
105b8a021dbSAart Bik 
/// Constructs a tensor expression of kind `k`. The meaning of the generic
/// operands `x` and `y` depends on the kind: for leaves they encode the
/// tensor id or loop id, for unary/binary operations they are child
/// expression ids (`detail::kInvalidId` when absent). The value `v`,
/// operation `o`, and attribute `a` are only meaningful for certain kinds;
/// the per-kind assertions below enforce which combinations are legal.
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    : kind(k), val(v), op(o), attr(a) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    // `x` is the tensor id; no child, value, or operation allowed.
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    // Synthetic zero carries no operands at all.
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    // Invariants carry only a value.
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    // `x` is the loop id.
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    // Plain unary op: one child in `x`; `y` stays invalid.
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    // Cast-like unary op: one child in `x`, plus a value `v` that carries
    // the destination type of the cast.
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // One child plus the custom operation `o` implementing the semantics.
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // Plain binary op: both children required.
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // Comparison: both children required; the predicate lives in `attr`.
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    // Custom binary/reduction: both children plus the operation `o`.
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kDenseOp:
    // Dense op: first child required, second may be invalid (unary form).
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}
223b8a021dbSAart Bik 
/// Constructs a merger for the given number of input/output tensors and
/// loops. One extra "synthetic" tensor (id `numInputOutputTensors`) is
/// appended internally to represent iteration over sparse output patterns,
/// so `numTensors` is one larger than the caller-visible count; the output
/// tensor is by convention the last input/output tensor. All per-tensor /
/// per-loop lookup tables start out unset (undefined level types, nullopt
/// mappings, empty dependency lists).
Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
2422e1caa47SAart Bik 
243e2d3db42SAart Bik //===----------------------------------------------------------------------===//
244266a7414SAart Bik // Lattice methods.
245e2d3db42SAart Bik //===----------------------------------------------------------------------===//
246266a7414SAart Bik 
addTensorExp(TensorId t)24746a384dfSwren romano ExprId Merger::addTensorExp(TensorId t) {
24846a384dfSwren romano   assert(isValidTensorId(t));
24946a384dfSwren romano   const ExprId eNew(tensorExps.size());
25046a384dfSwren romano   tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
251faf7cd97SPeiming Liu                           Value(), nullptr, nullptr);
25246a384dfSwren romano   return eNew;
25346a384dfSwren romano }
25446a384dfSwren romano 
addLoopVarExp(LoopId i)25546a384dfSwren romano ExprId Merger::addLoopVarExp(LoopId i) {
25646a384dfSwren romano   assert(isValidLoopId(i));
25746a384dfSwren romano   const ExprId eNew(tensorExps.size());
25846a384dfSwren romano   tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
259faf7cd97SPeiming Liu                           Value(), nullptr, nullptr);
26046a384dfSwren romano   return eNew;
26146a384dfSwren romano }
26246a384dfSwren romano 
addInvariantExp(Value v)26346a384dfSwren romano ExprId Merger::addInvariantExp(Value v) {
26446a384dfSwren romano   const ExprId eNew(tensorExps.size());
26546a384dfSwren romano   tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
266faf7cd97SPeiming Liu                           detail::kInvalidId, v, nullptr, nullptr);
26746a384dfSwren romano   return eNew;
26846a384dfSwren romano }
26946a384dfSwren romano 
addSynZeroExp()270faf7cd97SPeiming Liu ExprId Merger::addSynZeroExp() {
27146a384dfSwren romano   const ExprId eNew(tensorExps.size());
272faf7cd97SPeiming Liu   tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
273faf7cd97SPeiming Liu                           detail::kInvalidId, Value(), nullptr, nullptr);
27446a384dfSwren romano   return eNew;
27546a384dfSwren romano }
27646a384dfSwren romano 
addExp(TensorExp::Kind k,ExprId e0,ExprId e1,Operation * op,Attribute attr)277faf7cd97SPeiming Liu ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
278faf7cd97SPeiming Liu                       Attribute attr) {
27946a384dfSwren romano   assert(k > TensorExp::Kind::kLoopVar);
28046a384dfSwren romano   const ExprId eNew(tensorExps.size());
281faf7cd97SPeiming Liu   tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
282faf7cd97SPeiming Liu   return eNew;
283faf7cd97SPeiming Liu }
284faf7cd97SPeiming Liu 
addExp(TensorExp::Kind k,ExprId e,Value v,Operation * op,Attribute attr)285faf7cd97SPeiming Liu ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
286faf7cd97SPeiming Liu                       Attribute attr) {
287faf7cd97SPeiming Liu   assert(k > TensorExp::Kind::kLoopVar);
288faf7cd97SPeiming Liu   const ExprId eNew(tensorExps.size());
289faf7cd97SPeiming Liu   tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
29046a384dfSwren romano   return eNew;
291744146f6SGus Smith }
292744146f6SGus Smith 
addLat(TensorId t,LoopId i,ExprId e)293b8cf7af9Swren romano LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
29446a384dfSwren romano   const LatPointId pNew(latPoints.size());
29546a384dfSwren romano   const unsigned size = numLoops * numTensors;
29646a384dfSwren romano   const TensorLoopId b = makeTensorLoopId(t, i);
29746a384dfSwren romano   latPoints.emplace_back(size, e);
29846a384dfSwren romano   latPoints[pNew].bits.set(b);
29946a384dfSwren romano   return pNew;
300744146f6SGus Smith }
301744146f6SGus Smith 
addLat(const BitVector & bits,ExprId e)30213e9afd1Swren romano LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
30313e9afd1Swren romano   assert(bits.size() == numLoops * numTensors);
30446a384dfSwren romano   const LatPointId pNew(latPoints.size());
30513e9afd1Swren romano   latPoints.emplace_back(bits, e);
30646a384dfSwren romano   return pNew;
30713e9afd1Swren romano }
30813e9afd1Swren romano 
addSet()309b8cf7af9Swren romano LatSetId Merger::addSet() {
31046a384dfSwren romano   const LatSetId sNew(latSets.size());
3110e1708ffSAart Bik   latSets.emplace_back();
31246a384dfSwren romano   return sNew;
313744146f6SGus Smith }
314744146f6SGus Smith 
conjLat(ExprId e,LatPointId p0,LatPointId p1,Operation * op)315faf7cd97SPeiming Liu LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
3162c332660SJim Kitchen                            Operation *op) {
317faf7cd97SPeiming Liu   TensorExp::Kind kind = exp(e).kind;
318faf7cd97SPeiming Liu   Attribute attr = exp(e).attr;
31946a384dfSwren romano   const LatPointId pNew(latPoints.size());
32046a384dfSwren romano   const auto &point0 = lat(p0);
32146a384dfSwren romano   const auto &point1 = lat(p1);
32246a384dfSwren romano   BitVector bits(point0.bits);
32346a384dfSwren romano   bits |= point1.bits;
324faf7cd97SPeiming Liu   const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
325faf7cd97SPeiming Liu   latPoints.emplace_back(bits, ne);
32646a384dfSwren romano   return pNew;
327744146f6SGus Smith }
328744146f6SGus Smith 
conjSet(ExprId e,LatSetId s0,LatSetId s1,Operation * op)329faf7cd97SPeiming Liu LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
33046a384dfSwren romano   const LatSetId sNew = addSet();
33146a384dfSwren romano   auto &setNew = latSets[sNew];
33246a384dfSwren romano   for (const LatPointId p0 : set(s0))
33346a384dfSwren romano     for (const LatPointId p1 : set(s1))
334faf7cd97SPeiming Liu       setNew.push_back(conjLat(e, p0, p1, op));
33546a384dfSwren romano   return sNew;
336744146f6SGus Smith }
337744146f6SGus Smith 
disjSet(ExprId e,LatSetId s0,LatSetId s1,Operation * op)338faf7cd97SPeiming Liu LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
339faf7cd97SPeiming Liu   const LatSetId sNew = conjSet(e, s0, s1, op);
340faf7cd97SPeiming Liu   TensorExp::Kind kind = exp(e).kind;
341123e8dfcSAart Bik   // Followed by all in s0.
34246a384dfSwren romano   latSets[sNew].append(latSets[s0]);
343123e8dfcSAart Bik   // Map binary 0-y to unary -y.
3442c332660SJim Kitchen   // TODO: move this if-else logic into buildLattices
3451f58ae80Swren romano   if (kind == TensorExp::Kind::kSubF)
3461f58ae80Swren romano     s1 = mapSet(TensorExp::Kind::kNegF, s1);
3471f58ae80Swren romano   else if (kind == TensorExp::Kind::kSubC)
3481f58ae80Swren romano     s1 = mapSet(TensorExp::Kind::kNegC, s1);
3491f58ae80Swren romano   else if (kind == TensorExp::Kind::kSubI)
3501f58ae80Swren romano     s1 = mapSet(TensorExp::Kind::kNegI, s1);
351123e8dfcSAart Bik   // Followed by all in s1.
35246a384dfSwren romano   latSets[sNew].append(latSets[s1]);
35346a384dfSwren romano   return sNew;
354744146f6SGus Smith }
355744146f6SGus Smith 
/// Builds the disjunction of lattice sets `s0` and `s1` for a comparison
/// expression, additionally pairing each lone operand against a synthetic
/// zero (so that elements present on only one side still participate in
/// the comparison). If one side is already a synthetic zero, the plain
/// conjunction suffices.
LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build disjunctive set for it.
    return sNew;
  }

  // Pair each side's lone points against a synthetic zero on the other
  // side, and append both to the conjunction.
  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}
379faf7cd97SPeiming Liu 
combiSet(ExprId e,LatSetId s0,LatSetId s1,Operation * orig,bool includeLeft,TensorExp::Kind ltrans,Operation * opleft,bool includeRight,TensorExp::Kind rtrans,Operation * opright)380faf7cd97SPeiming Liu LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381faf7cd97SPeiming Liu                           bool includeLeft, TensorExp::Kind ltrans,
382faf7cd97SPeiming Liu                           Operation *opleft, bool includeRight,
383faf7cd97SPeiming Liu                           TensorExp::Kind rtrans, Operation *opright) {
384*70e227a4SAart Bik   Attribute a = exp(e).attr;
385faf7cd97SPeiming Liu   const LatSetId sNew = conjSet(e, s0, s1, orig);
3862c332660SJim Kitchen   // Left Region.
3872c332660SJim Kitchen   if (includeLeft) {
3882c332660SJim Kitchen     if (opleft)
389*70e227a4SAart Bik       s0 = mapSet(ltrans, s0, Value(), opleft, a);
39046a384dfSwren romano     latSets[sNew].append(latSets[s0]);
3912c332660SJim Kitchen   }
3922c332660SJim Kitchen   // Right Region.
3932c332660SJim Kitchen   if (includeRight) {
3942c332660SJim Kitchen     if (opright)
395*70e227a4SAart Bik       s1 = mapSet(rtrans, s1, Value(), opright, a);
39646a384dfSwren romano     latSets[sNew].append(latSets[s1]);
3972c332660SJim Kitchen   }
39846a384dfSwren romano   return sNew;
3992c332660SJim Kitchen }
4002c332660SJim Kitchen 
mapSet(TensorExp::Kind kind,LatSetId s0,Value v,Operation * op,Attribute a)4011f58ae80Swren romano LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
402*70e227a4SAart Bik                         Operation *op, Attribute a) {
403df11a2b4SPeiming Liu   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
404df11a2b4SPeiming Liu          TensorExp::Kind::kDenseOp == kind);
40546a384dfSwren romano   const LatSetId sNew = addSet();
40646a384dfSwren romano   auto &setNew = latSets[sNew];
40746a384dfSwren romano   for (const LatPointId p : set(s0)) {
40846a384dfSwren romano     const auto &point = latPoints[p];
409*70e227a4SAart Bik     setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
410b8a021dbSAart Bik   }
41146a384dfSwren romano   return sNew;
412b8a021dbSAart Bik }
413b8a021dbSAart Bik 
mapBinWithSynZeroSet(ExprId e,LatSetId s0,bool lhsZero)414faf7cd97SPeiming Liu LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
415faf7cd97SPeiming Liu   TensorExp::Kind kind = exp(e).kind;
416faf7cd97SPeiming Liu   Attribute a = exp(e).attr;
417faf7cd97SPeiming Liu   assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
418faf7cd97SPeiming Liu   // Must be a binary operation.
419faf7cd97SPeiming Liu   const LatSetId sNew = addSet();
420faf7cd97SPeiming Liu   auto &setNew = latSets[sNew];
421faf7cd97SPeiming Liu   const ExprId zeroExp = addSynZeroExp();
422faf7cd97SPeiming Liu   for (const LatPointId p : set(s0)) {
423faf7cd97SPeiming Liu     const auto &point = latPoints[p];
424faf7cd97SPeiming Liu     ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
425faf7cd97SPeiming Liu                             : addExp(kind, point.exp, zeroExp, nullptr, a);
426faf7cd97SPeiming Liu     setNew.push_back(addLat(point.bits, newExp));
427faf7cd97SPeiming Liu   }
428faf7cd97SPeiming Liu   return sNew;
429faf7cd97SPeiming Liu }
430faf7cd97SPeiming Liu 
/// Optimizes lattice set `s0` by dropping points that cannot contribute:
/// straightforward copies of the output tensor, and conjunctions already
/// covered (up to dense-only differences) by an earlier kept point. The
/// simplified per-point condition masks are computed afterwards.
LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  // The first point dominates all others (see the latGT assertions below).
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  // Precompute the simplified condition mask of each surviving point.
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}
460744146f6SGus Smith 
/// Computes the simplified condition mask of lattice point `p0` within set
/// `s0`: dense conditions after the first kept one are cleared, since
/// dense levels can always be "located" directly and need not drive the
/// loop condition. Sparse conditions are always kept.
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  // A singleton with at least one sparse condition may drop dense bits
  // right away; otherwise the first kept bit must not be a dense level.
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Starts resetting from a dense level, so that the first bit (if kept)
    // is not undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits reversely to always
  // keep the rightmost bit (which could possibly be a synthetic tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // Slice on dense level has `locate` property as well, and can be optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!lt.hasSparseSemantic()) {
        // Clear every dense condition after the first one that is kept.
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}
502744146f6SGus Smith 
latGT(LatPointId i,LatPointId j) const503b8cf7af9Swren romano bool Merger::latGT(LatPointId i, LatPointId j) const {
50446a384dfSwren romano   const BitVector &bitsi = lat(i).bits;
50546a384dfSwren romano   const BitVector &bitsj = lat(j).bits;
506744146f6SGus Smith   assert(bitsi.size() == bitsj.size());
507744146f6SGus Smith   if (bitsi.count() > bitsj.count()) {
508b8cf7af9Swren romano     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509744146f6SGus Smith       if (bitsj[b] && !bitsi[b])
510744146f6SGus Smith         return false;
511744146f6SGus Smith     return true;
512744146f6SGus Smith   }
513744146f6SGus Smith   return false;
514744146f6SGus Smith }
515744146f6SGus Smith 
onlyDenseDiff(LatPointId i,LatPointId j) const516b8cf7af9Swren romano bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
5175fd9d801SPeiming Liu   BitVector tmp(latPoints[j].bits);
5185fd9d801SPeiming Liu   tmp ^= latPoints[i].bits;
5195fd9d801SPeiming Liu   return !hasAnySparse(tmp);
520744146f6SGus Smith }
521744146f6SGus Smith 
/// Returns true if expression `e` (or any of its subexpressions)
/// references tensor `t`, by recursive descent driven by the operand
/// arity of each kind.
bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}
543faa75f94SPeiming Liu 
/// Returns true if expression `e` contains a negation (explicit kNeg* or
/// implicit via the subtrahend of a subtraction) applied to a
/// subexpression that references the output tensor.
bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    // Direct negation: check whether the operand touches the output.
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    // Subtraction negates its right operand implicitly.
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  case TensorExp::Kind::kDenseOp: {
    // kDenseOp may be unary: only recurse into e1 when it is present.
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
    return lhsNeg;
  }
  default: {
    // All other kinds: recurse into children per their arity.
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}
576faa75f94SPeiming Liu 
// Determines whether the expression `e` is nonzero only where tensor `t`
// is nonzero, i.e. whether iterating over `t` alone suffices to cover the
// expression (a "single condition" on `t`). The kind-by-kind rules below
// follow from zero-preservation of each operation.
bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    // Only a direct reference to `t` itself qualifies.
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  // These are zero-preserving (f(0) == 0), so the condition of the
  // operand passes straight through the operator.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // Custom regions: conservatively assume they may produce values
    // outside the iteration space of `t`.
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // Expression building only admits division by a nonzero invariant,
    // so the condition comes entirely from the numerator.
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // Shift amount is invariant by construction; condition comes from
    // the shifted operand.
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    // Conjunctive: one side must be single-condition on `t` and the
    // other side either invariant or also single-condition on `t`.
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    // Disjunctive: both sides must be conditioned on `t` alone.
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
    // These can be nonzero where `t` contributes nothing.
    return false;
  case TensorExp::Kind::kDenseOp:
    // Since Merger guarantees all the operands of the kDenseOp to be dense, the
    // operation must be single-condition.
    return true;
  }
  llvm_unreachable("unexpected kind");
}
67045b3cfe8SAart Bik 
hasAnySparse(const BitVector & bits) const67147a715d4SAart Bik bool Merger::hasAnySparse(const BitVector &bits) const {
6725fd9d801SPeiming Liu   for (TensorLoopId b : bits.set_bits()) {
6731dd387e1SAart Bik     const auto lt = getLvlType(b);
674d82e93e7SPeiming Liu     if (lt.hasSparseSemantic())
675b22397feSAart Bik       return true;
676b8cf7af9Swren romano   }
6775fd9d801SPeiming Liu   return hasSparseIdxReduction(bits);
678b22397feSAart Bik }
679b22397feSAart Bik 
hasSparseIdxReduction(const BitVector & bits) const6801328bb6eSPeiming Liu bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
6815fd9d801SPeiming Liu   for (TensorLoopId b : bits.set_bits())
6825fd9d801SPeiming Liu     if (isSparseLvlWithNonTrivialIdxExp(b))
6831328bb6eSPeiming Liu       return true;
6841328bb6eSPeiming Liu   return false;
6851328bb6eSPeiming Liu }
6861328bb6eSPeiming Liu 
687557b101cSAart Bik #ifndef NDEBUG
688557b101cSAart Bik 
689e2d3db42SAart Bik //===----------------------------------------------------------------------===//
690557b101cSAart Bik // Print methods (for debugging).
691e2d3db42SAart Bik //===----------------------------------------------------------------------===//
692557b101cSAart Bik 
kindToOpSymbol(TensorExp::Kind kind)6931f58ae80Swren romano static const char *kindToOpSymbol(TensorExp::Kind kind) {
6948fe65972SAart Bik   switch (kind) {
69506aa6ec8SAart Bik   // Leaf.
6961f58ae80Swren romano   case TensorExp::Kind::kTensor:
6978fe65972SAart Bik     return "tensor";
6981f58ae80Swren romano   case TensorExp::Kind::kInvariant:
6998fe65972SAart Bik     return "invariant";
7001f58ae80Swren romano   case TensorExp::Kind::kLoopVar:
70153cc3a06SAart Bik     return "index";
702faf7cd97SPeiming Liu   case TensorExp::Kind::kSynZero:
703faf7cd97SPeiming Liu     return "0";
70406aa6ec8SAart Bik   // Unary operations.
7051f58ae80Swren romano   case TensorExp::Kind::kAbsF:
7061f58ae80Swren romano   case TensorExp::Kind::kAbsC:
7071f58ae80Swren romano   case TensorExp::Kind::kAbsI:
7088fe65972SAart Bik     return "abs";
7091f58ae80Swren romano   case TensorExp::Kind::kCeilF:
7108fe65972SAart Bik     return "ceil";
7111f58ae80Swren romano   case TensorExp::Kind::kFloorF:
7128fe65972SAart Bik     return "floor";
7131f58ae80Swren romano   case TensorExp::Kind::kSqrtF:
7141f58ae80Swren romano   case TensorExp::Kind::kSqrtC:
715952fa301SAart Bik     return "sqrt";
7161f58ae80Swren romano   case TensorExp::Kind::kExpm1F:
7171f58ae80Swren romano   case TensorExp::Kind::kExpm1C:
718952fa301SAart Bik     return "expm1";
7191f58ae80Swren romano   case TensorExp::Kind::kLog1pF:
7201f58ae80Swren romano   case TensorExp::Kind::kLog1pC:
721952fa301SAart Bik     return "log1p";
722*70e227a4SAart Bik   case TensorExp::Kind::kRelu:
723*70e227a4SAart Bik     return "relu";
7241f58ae80Swren romano   case TensorExp::Kind::kSinF:
7251f58ae80Swren romano   case TensorExp::Kind::kSinC:
726952fa301SAart Bik     return "sin";
7271f58ae80Swren romano   case TensorExp::Kind::kTanhF:
7281f58ae80Swren romano   case TensorExp::Kind::kTanhC:
729952fa301SAart Bik     return "tanh";
7301f58ae80Swren romano   case TensorExp::Kind::kNegF:
7311f58ae80Swren romano   case TensorExp::Kind::kNegC:
7321f58ae80Swren romano   case TensorExp::Kind::kNegI:
7338fe65972SAart Bik     return "-";
7341f58ae80Swren romano   case TensorExp::Kind::kTruncF:
7351f58ae80Swren romano   case TensorExp::Kind::kExtF:
7361f58ae80Swren romano   case TensorExp::Kind::kCastFS:
7371f58ae80Swren romano   case TensorExp::Kind::kCastFU:
7381f58ae80Swren romano   case TensorExp::Kind::kCastSF:
7391f58ae80Swren romano   case TensorExp::Kind::kCastUF:
7401f58ae80Swren romano   case TensorExp::Kind::kCastS:
7411f58ae80Swren romano   case TensorExp::Kind::kCastU:
7421f58ae80Swren romano   case TensorExp::Kind::kCastIdx:
7431f58ae80Swren romano   case TensorExp::Kind::kTruncI:
7441f58ae80Swren romano   case TensorExp::Kind::kCIm:
74569edacbcSBixia Zheng     return "complex.im";
7461f58ae80Swren romano   case TensorExp::Kind::kCRe:
74769edacbcSBixia Zheng     return "complex.re";
7481f58ae80Swren romano   case TensorExp::Kind::kBitCast:
749e2d3db42SAart Bik     return "cast";
7501f58ae80Swren romano   case TensorExp::Kind::kBinaryBranch:
7512c332660SJim Kitchen     return "binary_branch";
7521f58ae80Swren romano   case TensorExp::Kind::kUnary:
7532c332660SJim Kitchen     return "unary";
7541f58ae80Swren romano   case TensorExp::Kind::kSelect:
75579193503SJim Kitchen     return "select";
75606aa6ec8SAart Bik   // Binary operations.
7571f58ae80Swren romano   case TensorExp::Kind::kMulF:
7581f58ae80Swren romano   case TensorExp::Kind::kMulC:
7591f58ae80Swren romano   case TensorExp::Kind::kMulI:
7608fe65972SAart Bik     return "*";
7611f58ae80Swren romano   case TensorExp::Kind::kDivF:
7621f58ae80Swren romano   case TensorExp::Kind::kDivC:
7631f58ae80Swren romano   case TensorExp::Kind::kDivS:
7641f58ae80Swren romano   case TensorExp::Kind::kDivU:
7658fe65972SAart Bik     return "/";
7661f58ae80Swren romano   case TensorExp::Kind::kAddF:
7671f58ae80Swren romano   case TensorExp::Kind::kAddC:
7681f58ae80Swren romano   case TensorExp::Kind::kAddI:
7698fe65972SAart Bik     return "+";
7701f58ae80Swren romano   case TensorExp::Kind::kSubF:
7711f58ae80Swren romano   case TensorExp::Kind::kSubC:
7721f58ae80Swren romano   case TensorExp::Kind::kSubI:
7738fe65972SAart Bik     return "-";
7741f58ae80Swren romano   case TensorExp::Kind::kAndI:
7758fe65972SAart Bik     return "&";
7761f58ae80Swren romano   case TensorExp::Kind::kOrI:
7778fe65972SAart Bik     return "|";
7781f58ae80Swren romano   case TensorExp::Kind::kXorI:
7798fe65972SAart Bik     return "^";
7801f58ae80Swren romano   case TensorExp::Kind::kShrS:
7818fe65972SAart Bik     return "a>>";
7821f58ae80Swren romano   case TensorExp::Kind::kShrU:
7838fe65972SAart Bik     return ">>";
7841f58ae80Swren romano   case TensorExp::Kind::kShlI:
7858fe65972SAart Bik     return "<<";
786faf7cd97SPeiming Liu   case TensorExp::Kind::kCmpF:
787faf7cd97SPeiming Liu   case TensorExp::Kind::kCmpI:
788faf7cd97SPeiming Liu     return "cmp";
7891f58ae80Swren romano   case TensorExp::Kind::kBinary:
7902c332660SJim Kitchen     return "binary";
7911f58ae80Swren romano   case TensorExp::Kind::kReduce:
792c8bb2354SJim Kitchen     return "reduce";
793df11a2b4SPeiming Liu   case TensorExp::Kind::kDenseOp:
794df11a2b4SPeiming Liu     return "dense";
7958fe65972SAart Bik   }
7968fe65972SAart Bik   llvm_unreachable("unexpected kind for symbol");
7978fe65972SAart Bik }
798b8a021dbSAart Bik 
dumpExp(ExprId e) const799b8cf7af9Swren romano void Merger::dumpExp(ExprId e) const {
80046a384dfSwren romano   const auto &expr = exp(e);
80146a384dfSwren romano   switch (expr.kind) {
80206aa6ec8SAart Bik   // Leaf.
8031f58ae80Swren romano   case TensorExp::Kind::kTensor:
80446a384dfSwren romano     if (expr.tensor == syntheticTensor)
805266a7414SAart Bik       llvm::dbgs() << "synthetic_";
80646a384dfSwren romano     else if (expr.tensor == outTensor)
807266a7414SAart Bik       llvm::dbgs() << "output_";
80846a384dfSwren romano     llvm::dbgs() << "tensor_" << expr.tensor;
809557b101cSAart Bik     break;
8101f58ae80Swren romano   case TensorExp::Kind::kInvariant:
811557b101cSAart Bik     llvm::dbgs() << "invariant";
812557b101cSAart Bik     break;
813faf7cd97SPeiming Liu   case TensorExp::Kind::kSynZero:
814faf7cd97SPeiming Liu     llvm::dbgs() << "0";
815faf7cd97SPeiming Liu     break;
8161f58ae80Swren romano   case TensorExp::Kind::kLoopVar:
81746a384dfSwren romano     llvm::dbgs() << "loopvar_" << expr.loop;
81853cc3a06SAart Bik     break;
81906aa6ec8SAart Bik   // Unary operations.
8201f58ae80Swren romano   case TensorExp::Kind::kAbsF:
8211f58ae80Swren romano   case TensorExp::Kind::kAbsC:
8221f58ae80Swren romano   case TensorExp::Kind::kAbsI:
8231f58ae80Swren romano   case TensorExp::Kind::kCeilF:
8241f58ae80Swren romano   case TensorExp::Kind::kFloorF:
8251f58ae80Swren romano   case TensorExp::Kind::kSqrtF:
8261f58ae80Swren romano   case TensorExp::Kind::kSqrtC:
8271f58ae80Swren romano   case TensorExp::Kind::kExpm1F:
8281f58ae80Swren romano   case TensorExp::Kind::kExpm1C:
8291f58ae80Swren romano   case TensorExp::Kind::kLog1pF:
8301f58ae80Swren romano   case TensorExp::Kind::kLog1pC:
831*70e227a4SAart Bik   case TensorExp::Kind::kRelu:
8321f58ae80Swren romano   case TensorExp::Kind::kSinF:
8331f58ae80Swren romano   case TensorExp::Kind::kSinC:
8341f58ae80Swren romano   case TensorExp::Kind::kTanhF:
8351f58ae80Swren romano   case TensorExp::Kind::kTanhC:
8361f58ae80Swren romano   case TensorExp::Kind::kNegF:
8371f58ae80Swren romano   case TensorExp::Kind::kNegC:
8381f58ae80Swren romano   case TensorExp::Kind::kNegI:
8391f58ae80Swren romano   case TensorExp::Kind::kTruncF:
8401f58ae80Swren romano   case TensorExp::Kind::kExtF:
8411f58ae80Swren romano   case TensorExp::Kind::kCastFS:
8421f58ae80Swren romano   case TensorExp::Kind::kCastFU:
8431f58ae80Swren romano   case TensorExp::Kind::kCastSF:
8441f58ae80Swren romano   case TensorExp::Kind::kCastUF:
8451f58ae80Swren romano   case TensorExp::Kind::kCastS:
8461f58ae80Swren romano   case TensorExp::Kind::kCastU:
8471f58ae80Swren romano   case TensorExp::Kind::kCastIdx:
8481f58ae80Swren romano   case TensorExp::Kind::kTruncI:
8491f58ae80Swren romano   case TensorExp::Kind::kCIm:
8501f58ae80Swren romano   case TensorExp::Kind::kCRe:
8511f58ae80Swren romano   case TensorExp::Kind::kBitCast:
8521f58ae80Swren romano   case TensorExp::Kind::kBinaryBranch:
8531f58ae80Swren romano   case TensorExp::Kind::kUnary:
8541f58ae80Swren romano   case TensorExp::Kind::kSelect:
85546a384dfSwren romano     llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
85646a384dfSwren romano     dumpExp(expr.children.e0);
857b8a021dbSAart Bik     break;
85806aa6ec8SAart Bik   // Binary operations.
8591f58ae80Swren romano   case TensorExp::Kind::kMulF:
8601f58ae80Swren romano   case TensorExp::Kind::kMulC:
8611f58ae80Swren romano   case TensorExp::Kind::kMulI:
8621f58ae80Swren romano   case TensorExp::Kind::kDivF:
8631f58ae80Swren romano   case TensorExp::Kind::kDivC:
8641f58ae80Swren romano   case TensorExp::Kind::kDivS:
8651f58ae80Swren romano   case TensorExp::Kind::kDivU:
8661f58ae80Swren romano   case TensorExp::Kind::kAddF:
8671f58ae80Swren romano   case TensorExp::Kind::kAddC:
8681f58ae80Swren romano   case TensorExp::Kind::kAddI:
8691f58ae80Swren romano   case TensorExp::Kind::kSubF:
8701f58ae80Swren romano   case TensorExp::Kind::kSubC:
8711f58ae80Swren romano   case TensorExp::Kind::kSubI:
8721f58ae80Swren romano   case TensorExp::Kind::kAndI:
8731f58ae80Swren romano   case TensorExp::Kind::kOrI:
8741f58ae80Swren romano   case TensorExp::Kind::kXorI:
8751f58ae80Swren romano   case TensorExp::Kind::kShrS:
8761f58ae80Swren romano   case TensorExp::Kind::kShrU:
8771f58ae80Swren romano   case TensorExp::Kind::kShlI:
878faf7cd97SPeiming Liu   case TensorExp::Kind::kCmpF:
879faf7cd97SPeiming Liu   case TensorExp::Kind::kCmpI:
8801f58ae80Swren romano   case TensorExp::Kind::kBinary:
8811f58ae80Swren romano   case TensorExp::Kind::kReduce:
882df11a2b4SPeiming Liu   case TensorExp::Kind::kDenseOp:
883557b101cSAart Bik     llvm::dbgs() << "(";
88446a384dfSwren romano     dumpExp(expr.children.e0);
885faf7cd97SPeiming Liu     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
886faf7cd97SPeiming Liu     if (expr.attr)
887faf7cd97SPeiming Liu       llvm::dbgs() << "{" << expr.attr << "}";
888df11a2b4SPeiming Liu     if (expr.children.e1 != detail::kInvalidId) {
889faf7cd97SPeiming Liu       llvm::dbgs() << " ";
89046a384dfSwren romano       dumpExp(expr.children.e1);
891557b101cSAart Bik       llvm::dbgs() << ")";
892df11a2b4SPeiming Liu     } else {
893df11a2b4SPeiming Liu       assert(expr.kind == TensorExp::Kind::kDenseOp);
894df11a2b4SPeiming Liu     }
895debdf7e0SAart Bik     break;
896557b101cSAart Bik   }
897557b101cSAart Bik }
898557b101cSAart Bik 
dumpLat(LatPointId p) const899b8cf7af9Swren romano void Merger::dumpLat(LatPointId p) const {
90046a384dfSwren romano   const auto &point = lat(p);
901557b101cSAart Bik   llvm::dbgs() << "lat(";
90246a384dfSwren romano   dumpBits(point.bits);
903557b101cSAart Bik   llvm::dbgs() << " :";
90446a384dfSwren romano   dumpBits(point.simple);
905b8a021dbSAart Bik   llvm::dbgs() << " : ";
90646a384dfSwren romano   dumpExp(point.exp);
907557b101cSAart Bik   llvm::dbgs() << " )\n";
908557b101cSAart Bik }
909557b101cSAart Bik 
dumpSet(LatSetId s) const910b8cf7af9Swren romano void Merger::dumpSet(LatSetId s) const {
91146a384dfSwren romano   const auto &ss = set(s);
91246a384dfSwren romano   llvm::dbgs() << "{ #" << ss.size() << "\n";
91346a384dfSwren romano   for (const LatPointId p : ss) {
914557b101cSAart Bik     llvm::dbgs() << "  ";
915557b101cSAart Bik     dumpLat(p);
916557b101cSAart Bik   }
917557b101cSAart Bik   llvm::dbgs() << "}\n";
918557b101cSAart Bik }
919557b101cSAart Bik 
dumpBits(const BitVector & bits) const920d10d49dcSRiver Riddle void Merger::dumpBits(const BitVector &bits) const {
921b8cf7af9Swren romano   for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
922557b101cSAart Bik     if (bits[b]) {
923b8cf7af9Swren romano       const TensorId t = tensor(b);
924b8cf7af9Swren romano       const LoopId i = loop(b);
9251dd387e1SAart Bik       const auto lt = lvlTypes[t][i];
926d03805f2SPeiming Liu       if (isLvlWithNonTrivialIdxExp(b))
927d03805f2SPeiming Liu         llvm::dbgs() << " DEP_" << t << "_" << i;
928d03805f2SPeiming Liu       else
9291dd387e1SAart Bik         llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
930557b101cSAart Bik     }
931557b101cSAart Bik   }
932557b101cSAart Bik }
933557b101cSAart Bik 
934557b101cSAart Bik #endif // NDEBUG
935557b101cSAart Bik 
936e2d3db42SAart Bik //===----------------------------------------------------------------------===//
937266a7414SAart Bik // Builder methods.
938e2d3db42SAart Bik //===----------------------------------------------------------------------===//
939266a7414SAart Bik 
buildLattices(ExprId e,LoopId i)940b8cf7af9Swren romano LatSetId Merger::buildLattices(ExprId e, LoopId i) {
94146a384dfSwren romano   // NOTE: The `expr` reference will be invalidated by recursive calls
94246a384dfSwren romano   // (and any other method that may add new expressions); therefore, the
94346a384dfSwren romano   // code below must make sure to copy fields of `expr` into local variables
94446a384dfSwren romano   // before making any recursive calls.
94546a384dfSwren romano   const auto &expr = exp(e);
94646a384dfSwren romano   const TensorExp::Kind kind = expr.kind;
947b8a021dbSAart Bik   switch (kind) {
94806aa6ec8SAart Bik   // Leaf.
9491f58ae80Swren romano   case TensorExp::Kind::kTensor:
9501f58ae80Swren romano   case TensorExp::Kind::kInvariant:
951faf7cd97SPeiming Liu   case TensorExp::Kind::kSynZero:
9521f58ae80Swren romano   case TensorExp::Kind::kLoopVar: {
953b8cf7af9Swren romano     // Either the loop-var is really used in the tensor expression, or it is
954b8cf7af9Swren romano     // set to the undefined loop-var in that level. An invariant expression,
95553cc3a06SAart Bik     // a proper index value, and a truly dynamic sparse output tensor are set
95653cc3a06SAart Bik     // to a synthetic tensor with undefined indices only to ensure the
95753cc3a06SAart Bik     // iteration space is not skipped as a result of their contents.
958b8cf7af9Swren romano     const LatSetId s = addSet();
959b8cf7af9Swren romano     TensorId t = syntheticTensor;
9601f58ae80Swren romano     if (kind == TensorExp::Kind::kTensor) {
96146a384dfSwren romano       t = expr.tensor;
9627d4da4e1SAart Bik       if (hasSparseOut && t == outTensor)
9637d4da4e1SAart Bik         t = syntheticTensor;
96453cc3a06SAart Bik     }
96545b3cfe8SAart Bik     latSets[s].push_back(addLat(t, i, e));
966266a7414SAart Bik     return s;
967266a7414SAart Bik   }
96806aa6ec8SAart Bik   // Unary operations.
9691f58ae80Swren romano   case TensorExp::Kind::kAbsF:
9701f58ae80Swren romano   case TensorExp::Kind::kAbsC:
9711f58ae80Swren romano   case TensorExp::Kind::kAbsI:
9721f58ae80Swren romano   case TensorExp::Kind::kCeilF:
9731f58ae80Swren romano   case TensorExp::Kind::kFloorF:
9741f58ae80Swren romano   case TensorExp::Kind::kSqrtF:
9751f58ae80Swren romano   case TensorExp::Kind::kSqrtC:
9761f58ae80Swren romano   case TensorExp::Kind::kExpm1F:
9771f58ae80Swren romano   case TensorExp::Kind::kExpm1C:
9781f58ae80Swren romano   case TensorExp::Kind::kLog1pF:
9791f58ae80Swren romano   case TensorExp::Kind::kLog1pC:
980*70e227a4SAart Bik   case TensorExp::Kind::kRelu:
9811f58ae80Swren romano   case TensorExp::Kind::kSinF:
9821f58ae80Swren romano   case TensorExp::Kind::kSinC:
9831f58ae80Swren romano   case TensorExp::Kind::kTanhF:
9841f58ae80Swren romano   case TensorExp::Kind::kTanhC:
9851f58ae80Swren romano   case TensorExp::Kind::kNegF:
9861f58ae80Swren romano   case TensorExp::Kind::kNegC:
9871f58ae80Swren romano   case TensorExp::Kind::kNegI:
9881f58ae80Swren romano   case TensorExp::Kind::kTruncF:
9891f58ae80Swren romano   case TensorExp::Kind::kExtF:
9901f58ae80Swren romano   case TensorExp::Kind::kCastFS:
9911f58ae80Swren romano   case TensorExp::Kind::kCastFU:
9921f58ae80Swren romano   case TensorExp::Kind::kCastSF:
9931f58ae80Swren romano   case TensorExp::Kind::kCastUF:
9941f58ae80Swren romano   case TensorExp::Kind::kCastS:
9951f58ae80Swren romano   case TensorExp::Kind::kCastU:
9961f58ae80Swren romano   case TensorExp::Kind::kCastIdx:
9971f58ae80Swren romano   case TensorExp::Kind::kTruncI:
9981f58ae80Swren romano   case TensorExp::Kind::kCIm:
9991f58ae80Swren romano   case TensorExp::Kind::kCRe:
10001f58ae80Swren romano   case TensorExp::Kind::kBitCast:
1001123e8dfcSAart Bik     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
1002123e8dfcSAart Bik     // lattice set of the operand through the operator into a new set.
1003123e8dfcSAart Bik     //
1004123e8dfcSAart Bik     //  -y|!y | y |
1005123e8dfcSAart Bik     //  --+---+---+
1006123e8dfcSAart Bik     //    | 0 |-y |
100746a384dfSwren romano     {
100846a384dfSwren romano       const ExprId e0 = expr.children.e0;
100946a384dfSwren romano       const Value v = expr.val;
1010*70e227a4SAart Bik       Attribute a = expr.attr;
1011*70e227a4SAart Bik       return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
101246a384dfSwren romano     }
10131f58ae80Swren romano   case TensorExp::Kind::kBinaryBranch:
10141f58ae80Swren romano   case TensorExp::Kind::kSelect:
10152c332660SJim Kitchen     // The left or right half of a binary operation which has already
10162c332660SJim Kitchen     // been split into separate operations for each region.
101746a384dfSwren romano     {
101846a384dfSwren romano       const ExprId e0 = expr.children.e0;
101946a384dfSwren romano       Operation *const op = expr.op;
102046a384dfSwren romano       return mapSet(kind, buildLattices(e0, i), Value(), op);
102146a384dfSwren romano     }
10221f58ae80Swren romano   case TensorExp::Kind::kUnary:
10232c332660SJim Kitchen     // A custom unary operation.
10242c332660SJim Kitchen     //
10252c332660SJim Kitchen     //  op y|    !y    |     y      |
10262c332660SJim Kitchen     //  ----+----------+------------+
10272c332660SJim Kitchen     //      | absent() | present(y) |
10282c332660SJim Kitchen     {
102946a384dfSwren romano       const ExprId e0 = expr.children.e0;
103046a384dfSwren romano       UnaryOp unop = cast<UnaryOp>(expr.op);
103146a384dfSwren romano       const LatSetId child0 = buildLattices(e0, i);
103204235d07SJacques Pienaar       Region &absentRegion = unop.getAbsentRegion();
10332c332660SJim Kitchen       if (absentRegion.empty()) {
10342c332660SJim Kitchen         // Simple mapping over existing values.
10352c332660SJim Kitchen         return mapSet(kind, child0, Value(), unop);
1036debdf7e0SAart Bik       }
1037debdf7e0SAart Bik       // Use a disjunction with `unop` on the left and the absent value as an
10382c332660SJim Kitchen       // invariant on the right.
10392c332660SJim Kitchen       Block &absentBlock = absentRegion.front();
10402c332660SJim Kitchen       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1041a54930e6SPeiming Liu       const Value absentVal = absentYield.getSingleResult();
104246a384dfSwren romano       const ExprId rhs = addInvariantExp(absentVal);
1043faf7cd97SPeiming Liu       return disjSet(e, child0, buildLattices(rhs, i), unop);
10442c332660SJim Kitchen     }
104506aa6ec8SAart Bik   // Binary operations.
10461f58ae80Swren romano   case TensorExp::Kind::kMulF:
10471f58ae80Swren romano   case TensorExp::Kind::kMulC:
10481f58ae80Swren romano   case TensorExp::Kind::kMulI:
10491f58ae80Swren romano   case TensorExp::Kind::kAndI:
1050622eb169SAart Bik     // A multiplicative operation only needs to be performed
1051622eb169SAart Bik     // for the conjunction of sparse iteration spaces.
1052622eb169SAart Bik     //
1053622eb169SAart Bik     //  x*y|!y | y |
1054622eb169SAart Bik     //  ---+---+---+
1055622eb169SAart Bik     //  !x | 0 | 0 |
1056622eb169SAart Bik     //   x | 0 |x*y|
1057736c1b66SAart Bik     //
1058736c1b66SAart Bik     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
105946a384dfSwren romano     {
106046a384dfSwren romano       const ExprId e0 = expr.children.e0;
106146a384dfSwren romano       const ExprId e1 = expr.children.e1;
1062faf7cd97SPeiming Liu       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
106346a384dfSwren romano     }
10641f58ae80Swren romano   case TensorExp::Kind::kDivF:
10651f58ae80Swren romano   case TensorExp::Kind::kDivC:
10661f58ae80Swren romano   case TensorExp::Kind::kDivS:
10671f58ae80Swren romano   case TensorExp::Kind::kDivU:
1068622eb169SAart Bik     // A division is tricky, since 0/0, 0/c, c/0 all have
1069622eb169SAart Bik     // specific outcomes for floating-point and integers.
1070622eb169SAart Bik     // Thus, we need to traverse the full iteration space.
1071622eb169SAart Bik     //
1072622eb169SAart Bik     //  x/y|!y | y |
1073622eb169SAart Bik     //  ---+---+---+
1074622eb169SAart Bik     //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1075622eb169SAart Bik     //   x |x/0|x/y|  INT: x/0=exception for any x
1076622eb169SAart Bik     //
1077622eb169SAart Bik     // TODO: for now we "fixed" this by only accepting x/c cases
1078622eb169SAart Bik     //       during expression building, so that the conjunction
1079622eb169SAart Bik     //       rules applies (viz. x/c = x*(1/c) as far as lattice
1080622eb169SAart Bik     //       construction is concerned).
108146a384dfSwren romano     {
108246a384dfSwren romano       const ExprId e0 = expr.children.e0;
108346a384dfSwren romano       const ExprId e1 = expr.children.e1;
108446a384dfSwren romano       assert(!maybeZero(e1));
1085faf7cd97SPeiming Liu       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
108646a384dfSwren romano     }
10871f58ae80Swren romano   case TensorExp::Kind::kAddF:
10881f58ae80Swren romano   case TensorExp::Kind::kAddC:
10891f58ae80Swren romano   case TensorExp::Kind::kAddI:
10901f58ae80Swren romano   case TensorExp::Kind::kSubF:
10911f58ae80Swren romano   case TensorExp::Kind::kSubC:
10921f58ae80Swren romano   case TensorExp::Kind::kSubI:
10931f58ae80Swren romano   case TensorExp::Kind::kOrI:
10941f58ae80Swren romano   case TensorExp::Kind::kXorI:
1095622eb169SAart Bik     // An additive operation needs to be performed
1096622eb169SAart Bik     // for the disjunction of sparse iteration spaces.
1097622eb169SAart Bik     //
1098622eb169SAart Bik     //  x+y|!y | y |    x-y|!y | y |
1099622eb169SAart Bik     //  ---+---+---+    ---+---+---+
1100622eb169SAart Bik     //  !x | 0 | y |    !x | 0 |-y |
1101622eb169SAart Bik     //   x | x |x+y|     x | x |x-y|
110246a384dfSwren romano     {
110346a384dfSwren romano       const ExprId e0 = expr.children.e0;
110446a384dfSwren romano       const ExprId e1 = expr.children.e1;
1105faf7cd97SPeiming Liu       return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1106faf7cd97SPeiming Liu     }
1107faf7cd97SPeiming Liu   case TensorExp::Kind::kCmpF:
1108faf7cd97SPeiming Liu   case TensorExp::Kind::kCmpI:
11092cb99df6SYinying Li     // A comparison operation needs to be performed
1110faf7cd97SPeiming Liu     // for the disjunction of sparse iteration spaces.
1111faf7cd97SPeiming Liu     //
1112faf7cd97SPeiming Liu     //   x < y |  !y   |   y   |
1113faf7cd97SPeiming Liu     //  -------+-------+-------+
1114faf7cd97SPeiming Liu     //     !x  |   0   | 0 < y |
1115faf7cd97SPeiming Liu     //      x  | x < 0 | x < y |
1116faf7cd97SPeiming Liu     {
1117faf7cd97SPeiming Liu       const ExprId e0 = expr.children.e0;
1118faf7cd97SPeiming Liu       const ExprId e1 = expr.children.e1;
1119faf7cd97SPeiming Liu       return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
112046a384dfSwren romano     }
11211f58ae80Swren romano   case TensorExp::Kind::kShrS:
11221f58ae80Swren romano   case TensorExp::Kind::kShrU:
11231f58ae80Swren romano   case TensorExp::Kind::kShlI:
11242b6e4332SAart Bik     // A shift operation by an invariant amount (viz. tensor expressions
11252b6e4332SAart Bik     // can only occur at the left-hand-side of the operator) can be handled
11262cb99df6SYinying Li     // with the conjunction rule.
112746a384dfSwren romano     {
112846a384dfSwren romano       const ExprId e0 = expr.children.e0;
112946a384dfSwren romano       const ExprId e1 = expr.children.e1;
113046a384dfSwren romano       assert(isInvariant(e1));
1131faf7cd97SPeiming Liu       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
113246a384dfSwren romano     }
11331f58ae80Swren romano   case TensorExp::Kind::kBinary:
11342c332660SJim Kitchen     // A custom binary operation.
11352c332660SJim Kitchen     //
11362c332660SJim Kitchen     //  x op y|   !y    |       y      |
11372c332660SJim Kitchen     //  ------+---------+--------------+
11382c332660SJim Kitchen     //    !x  |  empty  |   right(y)   |
11392c332660SJim Kitchen     //     x  | left(x) | overlap(x,y) |
11402c332660SJim Kitchen     {
114146a384dfSwren romano       const ExprId e0 = expr.children.e0;
114246a384dfSwren romano       const ExprId e1 = expr.children.e1;
114346a384dfSwren romano       BinaryOp binop = cast<BinaryOp>(expr.op);
114446a384dfSwren romano       const LatSetId child0 = buildLattices(e0, i);
114546a384dfSwren romano       const LatSetId child1 = buildLattices(e1, i);
114604235d07SJacques Pienaar       Region &leftRegion = binop.getLeftRegion();
114704235d07SJacques Pienaar       Region &rightRegion = binop.getRightRegion();
11482c332660SJim Kitchen       // Left Region.
11492c332660SJim Kitchen       Operation *leftYield = nullptr;
11502c332660SJim Kitchen       if (!leftRegion.empty()) {
11512c332660SJim Kitchen         Block &leftBlock = leftRegion.front();
11522c332660SJim Kitchen         leftYield = leftBlock.getTerminator();
11532c332660SJim Kitchen       }
11542c332660SJim Kitchen       // Right Region.
11552c332660SJim Kitchen       Operation *rightYield = nullptr;
11562c332660SJim Kitchen       if (!rightRegion.empty()) {
11572c332660SJim Kitchen         Block &rightBlock = rightRegion.front();
11582c332660SJim Kitchen         rightYield = rightBlock.getTerminator();
11592c332660SJim Kitchen       }
116004235d07SJacques Pienaar       bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
116104235d07SJacques Pienaar       bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1162faf7cd97SPeiming Liu       return combiSet(e, child0, child1, binop, includeLeft,
1163faf7cd97SPeiming Liu                       TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1164faf7cd97SPeiming Liu                       TensorExp::Kind::kBinaryBranch, rightYield);
11652c332660SJim Kitchen     }
11661f58ae80Swren romano   case TensorExp::Kind::kReduce:
1167c8bb2354SJim Kitchen     // A custom reduce operation.
116846a384dfSwren romano     {
116946a384dfSwren romano       const ExprId e0 = expr.children.e0;
117046a384dfSwren romano       const ExprId e1 = expr.children.e1;
117146a384dfSwren romano       Operation *const op = expr.op;
1172faf7cd97SPeiming Liu       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
117346a384dfSwren romano     }
1174df11a2b4SPeiming Liu   case TensorExp::Kind::kDenseOp: {
1175df11a2b4SPeiming Liu     // It does not really matter whether we use conjunctive/disjunctive set
1176df11a2b4SPeiming Liu     // here, as all the operands of kDenseOp must be dense, the disjunctive set
1177df11a2b4SPeiming Liu     // will be optimized into conjunctive set eventually.
1178df11a2b4SPeiming Liu     if (expr.children.e1 == detail::kInvalidId) {
1179df11a2b4SPeiming Liu       const ExprId e0 = expr.children.e0;
1180df11a2b4SPeiming Liu       Operation *const op = expr.op;
1181df11a2b4SPeiming Liu       return mapSet(kind, buildLattices(e0, i), Value(), op);
1182df11a2b4SPeiming Liu     }
1183df11a2b4SPeiming Liu 
1184df11a2b4SPeiming Liu     const ExprId e0 = expr.children.e0;
1185df11a2b4SPeiming Liu     const ExprId e1 = expr.children.e1;
1186df11a2b4SPeiming Liu     Operation *const op = expr.op;
1187df11a2b4SPeiming Liu     return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1188df11a2b4SPeiming Liu   }
1189266a7414SAart Bik   }
1190266a7414SAart Bik   llvm_unreachable("unexpected expression kind");
1191266a7414SAart Bik }
1192266a7414SAart Bik 
buildTensorExpFromLinalg(linalg::GenericOp op)1193b8cf7af9Swren romano std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
11942a288616SAart Bik   // Build the linalg semantics backward from yield.
1195d3b3f765SJacques Pienaar   Operation *yield = op.getRegion().front().getTerminator();
11962a288616SAart Bik   assert(isa<linalg::YieldOp>(yield));
1197df11a2b4SPeiming Liu   return buildTensorExp(op, yield->getOperand(0)).first;
1198266a7414SAart Bik }
1199266a7414SAart Bik 
1200*70e227a4SAart Bik /// Only returns true if we are certain this is a zero.
isCertainZero(Value val)1201*70e227a4SAart Bik static bool isCertainZero(Value val) {
1202*70e227a4SAart Bik   if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
1203*70e227a4SAart Bik     ArrayAttr arrayAttr = c.getValue();
1204*70e227a4SAart Bik     return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1205*70e227a4SAart Bik            cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1206*70e227a4SAart Bik   }
1207*70e227a4SAart Bik   if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
1208*70e227a4SAart Bik     return c.value() == 0;
1209*70e227a4SAart Bik   if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
1210*70e227a4SAart Bik     return c.value().isZero();
1211*70e227a4SAart Bik   return false;
1212*70e227a4SAart Bik }
1213*70e227a4SAart Bik 
121446e77b5dSAart Bik /// Only returns false if we are certain this is a nonzero.
maybeZero(ExprId e) const1215b8cf7af9Swren romano bool Merger::maybeZero(ExprId e) const {
121646a384dfSwren romano   const auto &expr = exp(e);
121746a384dfSwren romano   if (expr.kind == TensorExp::Kind::kInvariant) {
1218*70e227a4SAart Bik     // Note that this is different from isCertainZero() in a subtle
1219*70e227a4SAart Bik     // way by always returning true for non-constants.
122046a384dfSwren romano     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1221d390035bSBixia Zheng       ArrayAttr arrayAttr = c.getValue();
12225550c821STres Popp       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
12235550c821STres Popp              cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1224d390035bSBixia Zheng     }
122546a384dfSwren romano     if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1226a54f4eaeSMogball       return c.value() == 0;
122746a384dfSwren romano     if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1228a54f4eaeSMogball       return c.value().isZero();
1229622eb169SAart Bik   }
1230622eb169SAart Bik   return true;
1231622eb169SAart Bik }
1232622eb169SAart Bik 
inferType(ExprId e,Value src) const1233b8cf7af9Swren romano Type Merger::inferType(ExprId e, Value src) const {
1234e2d3db42SAart Bik   // Obtain the destination type from the cast node.
123546a384dfSwren romano   Type dtp = exp(e).val.getType();
1236e2d3db42SAart Bik   // Inspect source type. For vector types, apply the same
1237e2d3db42SAart Bik   // vectorization to the destination type.
12385550c821STres Popp   if (auto vtp = dyn_cast<VectorType>(src.getType()))
1239f22af204SAndrzej Warzynski     return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1240e2d3db42SAart Bik   return dtp;
1241e2d3db42SAart Bik }
1242e2d3db42SAart Bik 
1243c43e6274STim Harvey /// Ensures that the sparsifier can generate code for expression.
isAdmissibleBranchExp(Operation * op,Block * block,Value v)12445c327422SAart Bik static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1245a3610359SAart Bik   // Arguments are always admissible.
12465550c821STres Popp   if (isa<BlockArgument>(v))
12472a288616SAart Bik     return true;
12482a288616SAart Bik   // Accept index anywhere.
12492a288616SAart Bik   Operation *def = v.getDefiningOp();
12502a288616SAart Bik   if (isa<linalg::IndexOp>(def))
12512a288616SAart Bik     return true;
12522a288616SAart Bik   // Operation defined outside branch.
1253b22397feSAart Bik   if (def->getBlock() != block)
12542a288616SAart Bik     return def->getBlock() != op->getBlock(); // invariant?
12552a288616SAart Bik   // Operation defined within branch. Anything is accepted,
1256a3610359SAart Bik   // as long as all subexpressions are admissible.
12572a288616SAart Bik   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
12585c327422SAart Bik     if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
12592a288616SAart Bik       return false;
12602a288616SAart Bik   return true;
12612a288616SAart Bik }
12622a288616SAart Bik 
1263c43e6274STim Harvey /// Ensures that the sparsifier can generate code for branch.
isAdmissibleBranch(Operation * op,Region & region)12645c327422SAart Bik static bool isAdmissibleBranch(Operation *op, Region &region) {
12652a288616SAart Bik   if (region.empty())
12662a288616SAart Bik     return true;
12672a288616SAart Bik   // Build the semi-ring branch semantics backward from yield.
12682a288616SAart Bik   Operation *yield = region.front().getTerminator();
12692a288616SAart Bik   assert(isa<YieldOp>(yield));
12705c327422SAart Bik   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
12712a288616SAart Bik }
12722a288616SAart Bik 
1273*70e227a4SAart Bik // Recognizes a direct GT comparison.
isGreater(TensorExp::Kind kind,Attribute attr)1274*70e227a4SAart Bik static bool isGreater(TensorExp::Kind kind, Attribute attr) {
1275*70e227a4SAart Bik   if (kind == TensorExp::Kind::kCmpI) {
1276*70e227a4SAart Bik     auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
1277*70e227a4SAart Bik     return pred == arith::CmpIPredicate::ugt ||
1278*70e227a4SAart Bik            pred == arith::CmpIPredicate::sgt;
1279*70e227a4SAart Bik   }
1280*70e227a4SAart Bik   if (kind == TensorExp::Kind::kCmpF) {
1281*70e227a4SAart Bik     auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
1282*70e227a4SAart Bik     return pred == arith::CmpFPredicate::UGT ||
1283*70e227a4SAart Bik            pred == arith::CmpFPredicate::OGT;
1284*70e227a4SAart Bik   }
1285*70e227a4SAart Bik   return false;
1286*70e227a4SAart Bik }
1287*70e227a4SAart Bik 
1288df11a2b4SPeiming Liu std::pair<std::optional<ExprId>, bool>
buildTensorExp(linalg::GenericOp op,Value v)1289df11a2b4SPeiming Liu Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1290df11a2b4SPeiming Liu   // Recursion leaves.
12915550c821STres Popp   if (auto arg = dyn_cast<BlockArgument>(v)) {
129246a384dfSwren romano     const TensorId tid = makeTensorId(arg.getArgNumber());
1293266a7414SAart Bik     // Any argument of the generic op that is not marked as a scalar
1294266a7414SAart Bik     // argument is considered a tensor, indexed by the implicit loop
1295266a7414SAart Bik     // bounds. This includes rank-0 tensor arguments.
1296266a7414SAart Bik     if (arg.getOwner()->getParentOp() == op) {
129746a384dfSwren romano       OpOperand &t = op->getOpOperand(tid);
1298df11a2b4SPeiming Liu       bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1299a7cccb9cSAlexander Belyaev       if (!op.isScalar(&t))
1300df11a2b4SPeiming Liu         return {addTensorExp(tid), hasSpDep};
1301a7cccb9cSAlexander Belyaev       v = t.get(); // get scalar value
1302266a7414SAart Bik     }
1303266a7414SAart Bik     // Any other argument (marked as scalar argument for the generic op
1304266a7414SAart Bik     // or belonging to an enveloping op) is considered invariant.
1305df11a2b4SPeiming Liu     return {addInvariantExp(v), /*hasSpDep=*/false};
1306266a7414SAart Bik   }
1307*70e227a4SAart Bik 
1308266a7414SAart Bik   // Something defined outside is invariant.
130945b3cfe8SAart Bik   Operation *def = v.getDefiningOp();
1310d3b3f765SJacques Pienaar   if (def->getBlock() != &op.getRegion().front())
1311df11a2b4SPeiming Liu     return {addInvariantExp(v), /*hasSpDep=*/false};
131253cc3a06SAart Bik   // Construct index operations.
131353cc3a06SAart Bik   if (def->getNumOperands() == 0) {
131453cc3a06SAart Bik     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1315df11a2b4SPeiming Liu       return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
131653cc3a06SAart Bik   }
1317df11a2b4SPeiming Liu 
1318b8a021dbSAart Bik   // Construct unary operations if subexpression can be built.
1319b8a021dbSAart Bik   if (def->getNumOperands() == 1) {
1320df11a2b4SPeiming Liu     const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321491d2701SKazu Hirata     if (x.has_value()) {
1322b8cf7af9Swren romano       const ExprId e = *x;
132300f7096dSJeff Niu       if (isa<math::AbsFOp>(def))
1324df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1325d390035bSBixia Zheng       if (isa<complex::AbsOp>(def))
1326df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
13278dd07e36SAart Bik       if (isa<math::AbsIOp>(def))
1328df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1329a54f4eaeSMogball       if (isa<math::CeilOp>(def))
1330df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1331a54f4eaeSMogball       if (isa<math::FloorOp>(def))
1332df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1333952fa301SAart Bik       if (isa<math::SqrtOp>(def))
1334df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1335a14057d4Sbixia1       if (isa<complex::SqrtOp>(def))
1336df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1337952fa301SAart Bik       if (isa<math::ExpM1Op>(def))
1338df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1339a14057d4Sbixia1       if (isa<complex::Expm1Op>(def))
1340df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1341952fa301SAart Bik       if (isa<math::Log1pOp>(def))
1342df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1343d390035bSBixia Zheng       if (isa<complex::Log1pOp>(def))
1344df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1345952fa301SAart Bik       if (isa<math::SinOp>(def))
1346df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1347d390035bSBixia Zheng       if (isa<complex::SinOp>(def))
1348df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1349952fa301SAart Bik       if (isa<math::TanhOp>(def))
1350df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1351a14057d4Sbixia1       if (isa<complex::TanhOp>(def))
1352df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1353a54f4eaeSMogball       if (isa<arith::NegFOp>(def))
1354df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1355d390035bSBixia Zheng       if (isa<complex::NegOp>(def))
1356df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1357a54f4eaeSMogball       if (isa<arith::TruncFOp>(def))
1358df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1359a54f4eaeSMogball       if (isa<arith::ExtFOp>(def))
1360df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1361a54f4eaeSMogball       if (isa<arith::FPToSIOp>(def))
1362df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1363a54f4eaeSMogball       if (isa<arith::FPToUIOp>(def))
1364df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1365a54f4eaeSMogball       if (isa<arith::SIToFPOp>(def))
1366df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1367a54f4eaeSMogball       if (isa<arith::UIToFPOp>(def))
1368df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1369a54f4eaeSMogball       if (isa<arith::ExtSIOp>(def))
1370df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1371a54f4eaeSMogball       if (isa<arith::ExtUIOp>(def))
1372df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
137353cc3a06SAart Bik       if (isa<arith::IndexCastOp>(def))
1374df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1375a54f4eaeSMogball       if (isa<arith::TruncIOp>(def))
1376df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
137769edacbcSBixia Zheng       if (isa<complex::ImOp>(def))
1378df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
137969edacbcSBixia Zheng       if (isa<complex::ReOp>(def))
1380df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1381a54f4eaeSMogball       if (isa<arith::BitcastOp>(def))
1382df11a2b4SPeiming Liu         return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
13832a288616SAart Bik       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
13845c327422SAart Bik         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
13855c327422SAart Bik             isAdmissibleBranch(unop, unop.getAbsentRegion()))
1386df11a2b4SPeiming Liu           return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1387b8a021dbSAart Bik       }
138879193503SJim Kitchen       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
13895c327422SAart Bik         if (isAdmissibleBranch(selop, selop.getRegion()))
1390df11a2b4SPeiming Liu           return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
139179193503SJim Kitchen       }
1392b8a021dbSAart Bik     }
13932a288616SAart Bik   }
1394*70e227a4SAart Bik 
1395b8a021dbSAart Bik   // Construct binary operations if subexpressions can be built.
13967d4da4e1SAart Bik   // See buildLattices() for an explanation of rejecting certain
1397c8bb2354SJim Kitchen   // division and shift operations.
1398266a7414SAart Bik   if (def->getNumOperands() == 2) {
1399fc83eda4SPeiming Liu     const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
1400fc83eda4SPeiming Liu     const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1401fc83eda4SPeiming Liu     // For a conjunctive operation, it yields a "sparse" result if any operand
1402fc83eda4SPeiming Liu     // is sparse. For a disjunctive operation, it yields a "sparse" result if
1403fc83eda4SPeiming Liu     // all operands are sparse.
1404fc83eda4SPeiming Liu     bool conjSpVals = xSpVals || ySpVals;
1405fc83eda4SPeiming Liu     bool disjSpVals = xSpVals && ySpVals;
1406491d2701SKazu Hirata     if (x.has_value() && y.has_value()) {
1407b8cf7af9Swren romano       const ExprId e0 = *x;
1408b8cf7af9Swren romano       const ExprId e1 = *y;
1409a54f4eaeSMogball       if (isa<arith::MulFOp>(def))
1410fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
1411736c1b66SAart Bik       if (isa<complex::MulOp>(def))
1412fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
1413a54f4eaeSMogball       if (isa<arith::MulIOp>(def))
1414fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
1415a54f4eaeSMogball       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1416fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
1417d390035bSBixia Zheng       if (isa<complex::DivOp>(def) && !maybeZero(e1))
1418fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
1419a54f4eaeSMogball       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1420fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
1421a54f4eaeSMogball       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1422fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
1423a54f4eaeSMogball       if (isa<arith::AddFOp>(def))
1424fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
1425736c1b66SAart Bik       if (isa<complex::AddOp>(def))
1426fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
1427a54f4eaeSMogball       if (isa<arith::AddIOp>(def))
1428fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
1429a54f4eaeSMogball       if (isa<arith::SubFOp>(def))
1430fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
1431d390035bSBixia Zheng       if (isa<complex::SubOp>(def))
1432fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
1433a54f4eaeSMogball       if (isa<arith::SubIOp>(def))
1434fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
1435a54f4eaeSMogball       if (isa<arith::AndIOp>(def))
1436fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
1437a54f4eaeSMogball       if (isa<arith::OrIOp>(def))
1438fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
1439a54f4eaeSMogball       if (isa<arith::XOrIOp>(def))
1440fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
1441a54f4eaeSMogball       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1442fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
1443a54f4eaeSMogball       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1444fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
1445a54f4eaeSMogball       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1446fc83eda4SPeiming Liu         return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
1447faf7cd97SPeiming Liu       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1448faf7cd97SPeiming Liu         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1449faf7cd97SPeiming Liu             ci.getPredicate() == arith::CmpIPredicate::sle &&
1450faf7cd97SPeiming Liu             ci.getPredicate() == arith::CmpIPredicate::sge &&
1451faf7cd97SPeiming Liu             ci.getPredicate() == arith::CmpIPredicate::ule &&
1452faf7cd97SPeiming Liu             ci.getPredicate() == arith::CmpIPredicate::uge) {
1453faf7cd97SPeiming Liu           // We can not sparsify comparison with equal, this is because 0 <= 0
1454faf7cd97SPeiming Liu           // yields true, and thus densifies the result.
1455df11a2b4SPeiming Liu           return {std::nullopt, false};
1456faf7cd97SPeiming Liu         }
1457faf7cd97SPeiming Liu 
1458df11a2b4SPeiming Liu         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1459faf7cd97SPeiming Liu                         ci.getPredicateAttr());
1460fc83eda4SPeiming Liu         return {e, conjSpVals};
1461faf7cd97SPeiming Liu       }
1462faf7cd97SPeiming Liu       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1463faf7cd97SPeiming Liu         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1464faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::OGE &&
1465faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::OLE &&
1466faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::ONE &&
1467faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1468faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::UGE &&
1469faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::ULE &&
1470faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::ORD &&
1471faf7cd97SPeiming Liu             cf.getPredicate() == arith::CmpFPredicate::UNO) {
1472faf7cd97SPeiming Liu           // We can not sparsify comparison with equal, this is because 0 <= 0
1473faf7cd97SPeiming Liu           // yields true, and thus densifies the result.
1474df11a2b4SPeiming Liu           return {std::nullopt, false};
1475faf7cd97SPeiming Liu         }
1476df11a2b4SPeiming Liu         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1477faf7cd97SPeiming Liu                         cf.getPredicateAttr());
1478fc83eda4SPeiming Liu         return {e, conjSpVals};
1479faf7cd97SPeiming Liu       }
14802a288616SAart Bik       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
14815c327422SAart Bik         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
148204235d07SJacques Pienaar             (binop.getLeftIdentity() ||
14835c327422SAart Bik              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
148404235d07SJacques Pienaar             (binop.getRightIdentity() ||
14855c327422SAart Bik              isAdmissibleBranch(binop, binop.getRightRegion())))
1486fc83eda4SPeiming Liu           return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
1487266a7414SAart Bik       }
1488266a7414SAart Bik     }
14892a288616SAart Bik   }
1490*70e227a4SAart Bik 
1491c8bb2354SJim Kitchen   // Construct ternary operations if subexpressions can be built.
1492c8bb2354SJim Kitchen   if (def->getNumOperands() == 3) {
1493df11a2b4SPeiming Liu     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494df11a2b4SPeiming Liu     const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495df11a2b4SPeiming Liu     const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496df11a2b4SPeiming Liu     bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497c8bb2354SJim Kitchen     if (x.has_value() && y.has_value() && z.has_value()) {
1498b8cf7af9Swren romano       const ExprId e0 = *x;
1499b8cf7af9Swren romano       const ExprId e1 = *y;
1500c8bb2354SJim Kitchen       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
15015c327422SAart Bik         if (isAdmissibleBranch(redop, redop.getRegion()))
1502df11a2b4SPeiming Liu           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1503c8bb2354SJim Kitchen       }
1504*70e227a4SAart Bik       if (auto selop = dyn_cast<arith::SelectOp>(def)) {
1505*70e227a4SAart Bik         // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
1506*70e227a4SAart Bik         // operation inside a very specific ternary select operation.
1507*70e227a4SAart Bik         // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
1508*70e227a4SAart Bik         const auto &cnd = exp(*x);
1509*70e227a4SAart Bik         if (isGreater(cnd.kind, cnd.attr) &&
1510*70e227a4SAart Bik             exp(*y).kind == TensorExp::Kind::kTensor &&
1511*70e227a4SAart Bik             exp(*z).kind == TensorExp::Kind::kInvariant &&
1512*70e227a4SAart Bik             isCertainZero(exp(*z).val)) {
1513*70e227a4SAart Bik           const auto &a = exp(cnd.children.e0);
1514*70e227a4SAart Bik           const auto &b = exp(cnd.children.e1);
1515*70e227a4SAart Bik           if (a.kind == TensorExp::Kind::kTensor &&
1516*70e227a4SAart Bik               a.tensor == exp(*y).tensor &&
1517*70e227a4SAart Bik               b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
1518*70e227a4SAart Bik             return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
1519*70e227a4SAart Bik                            nullptr, cnd.attr),
1520*70e227a4SAart Bik                     yDepSp};
1521*70e227a4SAart Bik           }
1522*70e227a4SAart Bik         }
1523*70e227a4SAart Bik       }
1524c8bb2354SJim Kitchen     }
1525c8bb2354SJim Kitchen   }
1526df11a2b4SPeiming Liu 
1527df11a2b4SPeiming Liu   // If we reach here, we are dealing with an operation that is not currently
1528df11a2b4SPeiming Liu   // sparsifiable. We can still generate code for it if all its operands only
1529df11a2b4SPeiming Liu   // have dense dependencies (i.e., all the values are loaded from dense
1530df11a2b4SPeiming Liu   // tensors).
1531df11a2b4SPeiming Liu   if (def->getNumResults() != 1) // only handle single result operation.
1532df11a2b4SPeiming Liu     return {std::nullopt, false};
1533df11a2b4SPeiming Liu   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1534df11a2b4SPeiming Liu   // Builds all the sub-expressions
1535df11a2b4SPeiming Liu   for (Value operand : def->getOperands())
1536df11a2b4SPeiming Liu     subExp.push_back(buildTensorExp(op, operand));
1537df11a2b4SPeiming Liu 
1538df11a2b4SPeiming Liu   if (llvm::all_of(subExp,
1539df11a2b4SPeiming Liu                    [](auto e) { return e.first.has_value() && !e.second; })) {
1540df11a2b4SPeiming Liu     // All the subexpressions can be built and has *no* sparse dependencies.
1541df11a2b4SPeiming Liu     if (subExp.size() == 2) {
1542df11a2b4SPeiming Liu       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1543df11a2b4SPeiming Liu                       *subExp[1].first, def);
1544df11a2b4SPeiming Liu       return {e, false};
1545df11a2b4SPeiming Liu     }
1546df11a2b4SPeiming Liu     if (subExp.size() == 1) {
1547df11a2b4SPeiming Liu       auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1548df11a2b4SPeiming Liu                       detail::kInvalidId, def);
1549df11a2b4SPeiming Liu       return {e, false};
1550df11a2b4SPeiming Liu     }
1551df11a2b4SPeiming Liu   }
1552*70e227a4SAart Bik 
1553266a7414SAart Bik   // Cannot build.
1554df11a2b4SPeiming Liu   return {std::nullopt, false};
1555266a7414SAart Bik }
1556266a7414SAart Bik 
insertYieldOp(RewriterBase & rewriter,Location loc,Region & region,ValueRange vals)1557e9fa5590SMatthias Springer static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1558e9fa5590SMatthias Springer                            ValueRange vals) {
15592c332660SJim Kitchen   // Make a clone of overlap region.
15602c332660SJim Kitchen   Region tmpRegion;
15614d67b278SJeff Niu   IRMapping mapper;
15622c332660SJim Kitchen   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
15632c332660SJim Kitchen   Block &clonedBlock = tmpRegion.front();
15642c332660SJim Kitchen   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
15652c332660SJim Kitchen   // Merge cloned block and return yield value.
15662c332660SJim Kitchen   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
156742c31d83SMatthias Springer   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1568a54930e6SPeiming Liu   Value val = clonedYield.getSingleResult();
15692c332660SJim Kitchen   rewriter.eraseOp(clonedYield);
15702c332660SJim Kitchen   rewriter.eraseOp(placeholder);
15712c332660SJim Kitchen   return val;
15722c332660SJim Kitchen }
15732c332660SJim Kitchen 
buildUnaryPresent(RewriterBase & rewriter,Location loc,Operation * op,Value v0)1574e9fa5590SMatthias Springer static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
15752c332660SJim Kitchen                                Operation *op, Value v0) {
15762c332660SJim Kitchen   if (!v0)
15772c332660SJim Kitchen     // Empty input value must be propagated.
15782c332660SJim Kitchen     return Value();
15792c332660SJim Kitchen   UnaryOp unop = cast<UnaryOp>(op);
158004235d07SJacques Pienaar   Region &presentRegion = unop.getPresentRegion();
15812c332660SJim Kitchen   if (presentRegion.empty())
15822c332660SJim Kitchen     // Uninitialized Value() will be interpreted as missing data in the
15832c332660SJim Kitchen     // output.
15842c332660SJim Kitchen     return Value();
15852c332660SJim Kitchen   return insertYieldOp(rewriter, loc, presentRegion, {v0});
15862c332660SJim Kitchen }
15872c332660SJim Kitchen 
buildBinaryOverlap(RewriterBase & rewriter,Location loc,Operation * op,Value v0,Value v1)1588e9fa5590SMatthias Springer static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
15892c332660SJim Kitchen                                 Operation *op, Value v0, Value v1) {
15902c332660SJim Kitchen   if (!v0 || !v1)
15912c332660SJim Kitchen     // Empty input values must be propagated.
15922c332660SJim Kitchen     return Value();
15932c332660SJim Kitchen   BinaryOp binop = cast<BinaryOp>(op);
159404235d07SJacques Pienaar   Region &overlapRegion = binop.getOverlapRegion();
15952c332660SJim Kitchen   if (overlapRegion.empty())
15962c332660SJim Kitchen     // Uninitialized Value() will be interpreted as missing data in the
15972c332660SJim Kitchen     // output.
15982c332660SJim Kitchen     return Value();
15992c332660SJim Kitchen   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
16002c332660SJim Kitchen }
16012c332660SJim Kitchen 
buildRelu(RewriterBase & rewriter,Location loc,Value v0,Attribute attr)1602*70e227a4SAart Bik static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
1603*70e227a4SAart Bik                        Attribute attr) {
1604*70e227a4SAart Bik   Type tp = v0.getType();
1605*70e227a4SAart Bik   auto zero =
1606*70e227a4SAart Bik       rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
1607*70e227a4SAart Bik   Value cmp;
1608*70e227a4SAart Bik   if (isa<FloatType>(tp)) {
1609*70e227a4SAart Bik     auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
1610*70e227a4SAart Bik     cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
1611*70e227a4SAart Bik   } else {
1612*70e227a4SAart Bik     auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
1613*70e227a4SAart Bik     cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
1614*70e227a4SAart Bik   }
1615*70e227a4SAart Bik   return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
1616*70e227a4SAart Bik }
1617*70e227a4SAart Bik 
/// Emits the MLIR operation(s) that realize tensor expression `e`, given the
/// already-generated operand values `v0` (and `v1` for binary kinds).
/// Leaf kinds never reach this function (their values are produced by the
/// loop emitter), so hitting one is a programming error. Unary kinds use
/// only `v0`; binary kinds use both. Semi-ring kinds inline the user-
/// supplied region via insertYieldOp / buildUnaryPresent / buildBinaryOverlap.
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    // complex.abs yields the underlying float element type, not the
    // complex type itself, so the result type must be passed explicitly.
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kRelu:
    // The comparison predicate (float vs. integer) travels in expr.attr.
    return buildRelu(rewriter, loc, v0, expr.attr);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    // Integer negation is emitted as (0 - v0).
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    // complex.im extracts the float imaginary part; result type is the
    // complex type's element type.
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    // complex.re extracts the float real part, analogous to kCIm above.
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kCmpI: {
    // The integer comparison predicate is stored in expr.attr.
    auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kCmpF: {
    // The float comparison predicate is stored in expr.attr.
    auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    // Inline the region that encloses the branch's block.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc,
                         cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  case TensorExp::Kind::kDenseOp: {
    // Re-clone the captured dense op, remapping its original operand(s)
    // to the freshly generated values v0 (and v1 if binary).
    Operation *actualOp = expr.op;
    IRMapping mapping;
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}
177945b3cfe8SAart Bik 
1780744146f6SGus Smith } // namespace sparse_tensor
1781744146f6SGus Smith } // namespace mlir
1782