1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18
19 namespace mlir {
20 namespace sparse_tensor {
21
22 enum class ExpArity {
23 kNullary,
24 kUnary,
25 kBinary,
26 };
27
static ExpArity getExpArity(TensorExp::Kind k) {
29 switch (k) {
30 // Leaf.
31 case TensorExp::Kind::kTensor:
32 case TensorExp::Kind::kInvariant:
33 case TensorExp::Kind::kLoopVar:
34 case TensorExp::Kind::kSynZero:
35 return ExpArity::kNullary;
36 case TensorExp::Kind::kAbsF:
37 case TensorExp::Kind::kAbsC:
38 case TensorExp::Kind::kAbsI:
39 case TensorExp::Kind::kCeilF:
40 case TensorExp::Kind::kFloorF:
41 case TensorExp::Kind::kSqrtF:
42 case TensorExp::Kind::kSqrtC:
43 case TensorExp::Kind::kExpm1F:
44 case TensorExp::Kind::kExpm1C:
45 case TensorExp::Kind::kLog1pF:
46 case TensorExp::Kind::kLog1pC:
47 case TensorExp::Kind::kRelu:
48 case TensorExp::Kind::kSinF:
49 case TensorExp::Kind::kSinC:
50 case TensorExp::Kind::kTanhF:
51 case TensorExp::Kind::kTanhC:
52 case TensorExp::Kind::kTruncF:
53 case TensorExp::Kind::kExtF:
54 case TensorExp::Kind::kCastFS:
55 case TensorExp::Kind::kCastFU:
56 case TensorExp::Kind::kCastSF:
57 case TensorExp::Kind::kCastUF:
58 case TensorExp::Kind::kCastS:
59 case TensorExp::Kind::kCastU:
60 case TensorExp::Kind::kCastIdx:
61 case TensorExp::Kind::kTruncI:
62 case TensorExp::Kind::kCIm:
63 case TensorExp::Kind::kCRe:
64 case TensorExp::Kind::kBitCast:
65 case TensorExp::Kind::kBinaryBranch:
66 case TensorExp::Kind::kUnary:
67 case TensorExp::Kind::kSelect:
68 case TensorExp::Kind::kNegF:
69 case TensorExp::Kind::kNegC:
70 case TensorExp::Kind::kNegI:
71 return ExpArity::kUnary;
72 // Binary operations.
73 case TensorExp::Kind::kDivF:
74 case TensorExp::Kind::kDivC:
75 case TensorExp::Kind::kDivS:
76 case TensorExp::Kind::kDivU:
77 case TensorExp::Kind::kShrS:
78 case TensorExp::Kind::kShrU:
79 case TensorExp::Kind::kShlI:
80 case TensorExp::Kind::kMulF:
81 case TensorExp::Kind::kMulC:
82 case TensorExp::Kind::kMulI:
83 case TensorExp::Kind::kAndI:
84 case TensorExp::Kind::kAddF:
85 case TensorExp::Kind::kAddC:
86 case TensorExp::Kind::kAddI:
87 case TensorExp::Kind::kOrI:
88 case TensorExp::Kind::kXorI:
89 case TensorExp::Kind::kBinary:
90 case TensorExp::Kind::kReduce:
91 case TensorExp::Kind::kSubF:
92 case TensorExp::Kind::kSubC:
93 case TensorExp::Kind::kSubI:
94 case TensorExp::Kind::kCmpF:
95 case TensorExp::Kind::kCmpI:
96 case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
97 return ExpArity::kBinary;
98 }
99 llvm_unreachable("unexpected kind");
100 }
101
102 //===----------------------------------------------------------------------===//
103 // Constructors.
104 //===----------------------------------------------------------------------===//
105
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
107 Operation *o, Attribute a)
108 : kind(k), val(v), op(o), attr(a) {
109 switch (kind) {
110 // Leaf.
111 case TensorExp::Kind::kTensor:
112 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
113 tensor = x;
114 return;
115 case TensorExp::Kind::kSynZero:
116 assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
117 return;
118 case TensorExp::Kind::kInvariant:
119 assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
120 return;
121 case TensorExp::Kind::kLoopVar:
122 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
123 loop = x;
124 return;
125 // Unary operations.
126 case TensorExp::Kind::kAbsF:
127 case TensorExp::Kind::kAbsC:
128 case TensorExp::Kind::kAbsI:
129 case TensorExp::Kind::kCeilF:
130 case TensorExp::Kind::kFloorF:
131 case TensorExp::Kind::kSqrtF:
132 case TensorExp::Kind::kSqrtC:
133 case TensorExp::Kind::kExpm1F:
134 case TensorExp::Kind::kExpm1C:
135 case TensorExp::Kind::kLog1pF:
136 case TensorExp::Kind::kLog1pC:
137 case TensorExp::Kind::kRelu:
138 case TensorExp::Kind::kSinF:
139 case TensorExp::Kind::kSinC:
140 case TensorExp::Kind::kTanhF:
141 case TensorExp::Kind::kTanhC:
142 case TensorExp::Kind::kNegF:
143 case TensorExp::Kind::kNegC:
144 case TensorExp::Kind::kNegI:
145 case TensorExp::Kind::kCIm:
146 case TensorExp::Kind::kCRe:
147 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
148 children.e0 = x;
149 children.e1 = y;
150 return;
151 case TensorExp::Kind::kTruncF:
152 case TensorExp::Kind::kExtF:
153 case TensorExp::Kind::kCastFS:
154 case TensorExp::Kind::kCastFU:
155 case TensorExp::Kind::kCastSF:
156 case TensorExp::Kind::kCastUF:
157 case TensorExp::Kind::kCastS:
158 case TensorExp::Kind::kCastU:
159 case TensorExp::Kind::kCastIdx:
160 case TensorExp::Kind::kTruncI:
161 case TensorExp::Kind::kBitCast:
162 assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
163 children.e0 = x;
164 children.e1 = y;
165 return;
166 case TensorExp::Kind::kBinaryBranch:
167 case TensorExp::Kind::kSelect:
168 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
169 children.e0 = x;
170 children.e1 = y;
171 return;
172 case TensorExp::Kind::kUnary:
173 // No assertion on y can be made, as the branching paths involve both
174 // a unary (`mapSet`) and binary (`disjSet`) pathway.
175 assert(x != detail::kInvalidId && !v && o);
176 children.e0 = x;
177 children.e1 = y;
178 return;
179 // Binary operations.
180 case TensorExp::Kind::kMulF:
181 case TensorExp::Kind::kMulC:
182 case TensorExp::Kind::kMulI:
183 case TensorExp::Kind::kDivF:
184 case TensorExp::Kind::kDivC:
185 case TensorExp::Kind::kDivS:
186 case TensorExp::Kind::kDivU:
187 case TensorExp::Kind::kAddF:
188 case TensorExp::Kind::kAddC:
189 case TensorExp::Kind::kAddI:
190 case TensorExp::Kind::kSubF:
191 case TensorExp::Kind::kSubC:
192 case TensorExp::Kind::kSubI:
193 case TensorExp::Kind::kAndI:
194 case TensorExp::Kind::kOrI:
195 case TensorExp::Kind::kXorI:
196 case TensorExp::Kind::kShrS:
197 case TensorExp::Kind::kShrU:
198 case TensorExp::Kind::kShlI:
199 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
200 children.e0 = x;
201 children.e1 = y;
202 return;
203 case TensorExp::Kind::kCmpF:
204 case TensorExp::Kind::kCmpI:
205 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
206 children.e0 = x;
207 children.e1 = y;
208 return;
209 case TensorExp::Kind::kBinary:
210 case TensorExp::Kind::kReduce:
211 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
212 children.e0 = x;
213 children.e1 = y;
214 return;
215 case TensorExp::Kind::kDenseOp:
216 assert(x != detail::kInvalidId && !v && o);
217 children.e0 = x;
218 children.e1 = y;
219 return;
220 }
221 llvm_unreachable("unexpected kind");
222 }
223
Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
225 unsigned maxLvlRank)
226 : outTensor(numInputOutputTensors - 1),
227 syntheticTensor(numInputOutputTensors),
228 numTensors(numInputOutputTensors + 1), numLoops(numLoops),
229 hasSparseOut(false),
230 lvlTypes(numTensors,
231 std::vector<LevelType>(numLoops, LevelFormat::Undef)),
232 loopToLvl(numTensors,
233 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
234 lvlToLoop(numTensors,
235 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
236 loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
237 numTensors, std::nullopt)),
238 levelToDependentLoop(numTensors,
239 std::vector<std::vector<LoopCoeffPair>>(
240 maxLvlRank, std::vector<LoopCoeffPair>())),
241 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
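
// For example, with two inputs and one output, numInputOutputTensors is 3:
// the inputs get TensorIds 0 and 1, the output gets TensorId 2 (that is,
// numInputOutputTensors - 1), the extra synthetic tensor gets TensorId 3,
// and numTensors becomes 4.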
242
243 //===----------------------------------------------------------------------===//
244 // Lattice methods.
245 //===----------------------------------------------------------------------===//
246
ExprId Merger::addTensorExp(TensorId t) {
248 assert(isValidTensorId(t));
249 const ExprId eNew(tensorExps.size());
250 tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
251 Value(), nullptr, nullptr);
252 return eNew;
253 }
254
ExprId Merger::addLoopVarExp(LoopId i) {
256 assert(isValidLoopId(i));
257 const ExprId eNew(tensorExps.size());
258 tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
259 Value(), nullptr, nullptr);
260 return eNew;
261 }
262
ExprId Merger::addInvariantExp(Value v) {
264 const ExprId eNew(tensorExps.size());
265 tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
266 detail::kInvalidId, v, nullptr, nullptr);
267 return eNew;
268 }
269
ExprId Merger::addSynZeroExp() {
271 const ExprId eNew(tensorExps.size());
272 tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
273 detail::kInvalidId, Value(), nullptr, nullptr);
274 return eNew;
275 }
276
ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
278 Attribute attr) {
279 assert(k > TensorExp::Kind::kLoopVar);
280 const ExprId eNew(tensorExps.size());
281 tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
282 return eNew;
283 }
284
ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
286 Attribute attr) {
287 assert(k > TensorExp::Kind::kLoopVar);
288 const ExprId eNew(tensorExps.size());
289 tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
290 return eNew;
291 }
292
LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
294 const LatPointId pNew(latPoints.size());
295 const unsigned size = numLoops * numTensors;
296 const TensorLoopId b = makeTensorLoopId(t, i);
297 latPoints.emplace_back(size, e);
298 latPoints[pNew].bits.set(b);
299 return pNew;
300 }
301
LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
303 assert(bits.size() == numLoops * numTensors);
304 const LatPointId pNew(latPoints.size());
305 latPoints.emplace_back(bits, e);
306 return pNew;
307 }
308
LatSetId Merger::addSet() {
310 const LatSetId sNew(latSets.size());
311 latSets.emplace_back();
312 return sNew;
313 }
314
LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
316 Operation *op) {
317 TensorExp::Kind kind = exp(e).kind;
318 Attribute attr = exp(e).attr;
319 const LatPointId pNew(latPoints.size());
320 const auto &point0 = lat(p0);
321 const auto &point1 = lat(p1);
322 BitVector bits(point0.bits);
323 bits |= point1.bits;
324 const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
325 latPoints.emplace_back(bits, ne);
326 return pNew;
327 }
328
LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
330 const LatSetId sNew = addSet();
331 auto &setNew = latSets[sNew];
332 for (const LatPointId p0 : set(s0))
333 for (const LatPointId p1 : set(s1))
334 setNew.push_back(conjLat(e, p0, p1, op));
335 return sNew;
336 }
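
// For example, conjoining the singleton sets {a(i)} and {b(i)} for a
// multiplication yields the single lattice point {a(i), b(i)} computing
// a(i)*b(i): for every pair of points, conjLat ORs the loop bits and
// builds a new expression over the two operand expressions.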
337
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
339 const LatSetId sNew = conjSet(e, s0, s1, op);
340 TensorExp::Kind kind = exp(e).kind;
341 // Followed by all in s0.
342 latSets[sNew].append(latSets[s0]);
343 // Map binary 0-y to unary -y.
344 // TODO: move this if-else logic into buildLattices
345 if (kind == TensorExp::Kind::kSubF)
346 s1 = mapSet(TensorExp::Kind::kNegF, s1);
347 else if (kind == TensorExp::Kind::kSubC)
348 s1 = mapSet(TensorExp::Kind::kNegC, s1);
349 else if (kind == TensorExp::Kind::kSubI)
350 s1 = mapSet(TensorExp::Kind::kNegI, s1);
351 // Followed by all in s1.
352 latSets[sNew].append(latSets[s1]);
353 return sNew;
354 }
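
// For example, for x - y the resulting set contains, in order, the
// conjunction points (both x and y present), the points of s0 alone
// (x present, y absent), and the points of s1 mapped to unary negation
// (x absent, y present, evaluated as -y).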
355
LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
357 assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358 exp(e).kind == TensorExp::Kind::kCmpF);
359 const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360
361 ExprId e0 = exp(e).children.e0;
362 ExprId e1 = exp(e).children.e1;
363 if (exp(e0).kind == TensorExp::Kind::kSynZero ||
364 exp(e1).kind == TensorExp::Kind::kSynZero) {
365 // lhs and rhs can't be synthetic zero at the same time.
366 assert(exp(e0).kind != exp(e1).kind);
367 // If one of the operands has already been assigned to zero (the
368 // element is absent in the corresponding operand), then we do not
369 // need to build disjunctive set for it.
370 return sNew;
371 }
372
373 auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
374 auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
375 latSets[sNew].append(latSets[lhsSet]);
376 latSets[sNew].append(latSets[rhsSet]);
377 return sNew;
378 }
379
LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381 bool includeLeft, TensorExp::Kind ltrans,
382 Operation *opleft, bool includeRight,
383 TensorExp::Kind rtrans, Operation *opright) {
384 Attribute a = exp(e).attr;
385 const LatSetId sNew = conjSet(e, s0, s1, orig);
386 // Left Region.
387 if (includeLeft) {
388 if (opleft)
389 s0 = mapSet(ltrans, s0, Value(), opleft, a);
390 latSets[sNew].append(latSets[s0]);
391 }
392 // Right Region.
393 if (includeRight) {
394 if (opright)
395 s1 = mapSet(rtrans, s1, Value(), opright, a);
396 latSets[sNew].append(latSets[s1]);
397 }
398 return sNew;
399 }
400
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
402 Operation *op, Attribute a) {
403 assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
404 TensorExp::Kind::kDenseOp == kind);
405 const LatSetId sNew = addSet();
406 auto &setNew = latSets[sNew];
407 for (const LatPointId p : set(s0)) {
408 const auto &point = latPoints[p];
409 setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
410 }
411 return sNew;
412 }
413
LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
415 TensorExp::Kind kind = exp(e).kind;
416 Attribute a = exp(e).attr;
417 assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
418 // Must be a binary operation.
419 const LatSetId sNew = addSet();
420 auto &setNew = latSets[sNew];
421 const ExprId zeroExp = addSynZeroExp();
422 for (const LatPointId p : set(s0)) {
423 const auto &point = latPoints[p];
424 ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
425 : addExp(kind, point.exp, zeroExp, nullptr, a);
426 setNew.push_back(addLat(point.bits, newExp));
427 }
428 return sNew;
429 }
430
LatSetId Merger::optimizeSet(LatSetId s0) {
432 const LatSetId sNew = addSet();
433 auto &setNew = latSets[sNew];
434 const auto &set0 = set(s0);
435 assert(!set0.empty());
436 const LatPointId p0 = set0[0];
437 for (const LatPointId p1 : set0) {
438 bool add = true;
439 if (p0 != p1) {
440 // Check whether this is a straightforward copy.
441 if (expIsTensor(latPoints[p1].exp, outTensor))
442 continue;
443 // Check whether this conjunction is already covered.
444 for (const LatPointId p2 : setNew) {
445 assert(!latGT(p1, p2)); // Lj => Li would be bad
446 if (onlyDenseDiff(p2, p1)) {
447 add = false;
448 break;
449 }
450 }
451 assert(!add || latGT(p0, p1));
452 }
453 if (add)
454 setNew.push_back(p1);
455 }
456 for (const LatPointId p : setNew)
457 latPoints[p].simple = simplifyCond(sNew, p);
458 return sNew;
459 }
460
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e., the last
  // point in a lattice; no other point is less than this one.
464 bool isSingleton = true;
465 for (const LatPointId p1 : set(s0)) {
466 if (p0 != p1 && latGT(p0, p1)) {
467 isSingleton = false;
468 break;
469 }
470 }
471
472 BitVector simple(latPoints[p0].bits);
473 bool reset = isSingleton && hasAnySparse(simple);
474 const TensorLoopId be = simple.size();
475 TensorLoopId offset = 0; // relative to the end
476 if (!reset)
477 // Starts resetting from a dense level, so that the first bit (if kept)
478 // is not undefined level-type.
479 for (unsigned b = 0; b < be; b++) {
480 if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
481 offset = be - b - 1; // relative to the end
482 break;
483 }
484 }
485
  // Now apply the two basic rules. We also iterate over the bits in reverse
  // to always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
488 for (unsigned b = be - 1 - offset, i = 0; i < be;
489 b = b == 0 ? be - 1 : b - 1, i++) {
490 // Slice on dense level has `locate` property as well, and can be optimized.
491 if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
492 const auto lt = getLvlType(b);
493 if (!lt.hasSparseSemantic()) {
494 if (reset)
495 simple.reset(b);
496 reset = true;
497 }
498 }
499 }
500 return simple;
501 }
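
// For example, a singleton lattice point whose condition is {a sparse,
// b dense} simplifies to just the sparse bit: the dense level can be
// accessed directly through its "locate" property, so it does not need to
// participate in the generated iteration condition.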
502
bool Merger::latGT(LatPointId i, LatPointId j) const {
504 const BitVector &bitsi = lat(i).bits;
505 const BitVector &bitsj = lat(j).bits;
506 assert(bitsi.size() == bitsj.size());
507 if (bitsi.count() > bitsj.count()) {
508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509 if (bitsj[b] && !bitsi[b])
510 return false;
511 return true;
512 }
513 return false;
514 }
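
// For example, latGT holds for a point with condition {a(i), b(i)} versus a
// point with condition {a(i)} only, since the former covers strictly more
// loop bits; it does not hold between {a(i)} and {b(i)}.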
515
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
517 BitVector tmp(latPoints[j].bits);
518 tmp ^= latPoints[i].bits;
519 return !hasAnySparse(tmp);
520 }
521
bool Merger::expContainsTensor(ExprId e, TensorId t) const {
523 const auto &expr = exp(e);
524 // First we check `expIsTensor`.
525 if (expr.kind == TensorExp::Kind::kTensor)
526 return expr.tensor == t;
527
528 switch (getExpArity(expr.kind)) {
529 case ExpArity::kNullary:
530 return false;
531 case ExpArity::kUnary: {
532 const ExprId e0 = expr.children.e0;
533 return expContainsTensor(e0, t);
534 }
535 case ExpArity::kBinary: {
536 const ExprId e0 = expr.children.e0;
537 const ExprId e1 = expr.children.e1;
538 return expContainsTensor(e0, t) || expContainsTensor(e1, t);
539 }
540 }
541 llvm_unreachable("unexpected arity");
542 }
543
bool Merger::hasNegateOnOut(ExprId e) const {
545 const auto &expr = exp(e);
546 switch (expr.kind) {
547 case TensorExp::Kind::kNegF:
548 case TensorExp::Kind::kNegC:
549 case TensorExp::Kind::kNegI:
550 return expContainsTensor(expr.children.e0, outTensor);
551 case TensorExp::Kind::kSubF:
552 case TensorExp::Kind::kSubC:
553 case TensorExp::Kind::kSubI:
554 return expContainsTensor(expr.children.e1, outTensor) ||
555 hasNegateOnOut(expr.children.e0);
556 case TensorExp::Kind::kDenseOp: {
557 bool lhsNeg = hasNegateOnOut(expr.children.e0);
558 if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
559 return hasNegateOnOut(expr.children.e1);
560 return lhsNeg;
561 }
562 default: {
563 switch (getExpArity(expr.kind)) {
564 case ExpArity::kNullary:
565 return false;
566 case ExpArity::kUnary:
567 return hasNegateOnOut(expr.children.e0);
568 case ExpArity::kBinary:
569 return hasNegateOnOut(expr.children.e0) ||
570 hasNegateOnOut(expr.children.e1);
571 }
572 }
573 }
574 llvm_unreachable("unexpected kind");
575 }
576
bool Merger::isSingleCondition(TensorId t, ExprId e) const {
578 assert(isValidTensorId(t));
579 const auto &expr = exp(e);
580 switch (expr.kind) {
581 // Leaf.
582 case TensorExp::Kind::kTensor:
583 return expr.tensor == t;
584 case TensorExp::Kind::kInvariant:
585 case TensorExp::Kind::kLoopVar:
586 case TensorExp::Kind::kSynZero:
587 return false;
588 // Unary operations.
589 case TensorExp::Kind::kAbsF:
590 case TensorExp::Kind::kAbsC:
591 case TensorExp::Kind::kAbsI:
592 case TensorExp::Kind::kCeilF:
593 case TensorExp::Kind::kFloorF:
594 case TensorExp::Kind::kSqrtF:
595 case TensorExp::Kind::kSqrtC:
596 case TensorExp::Kind::kExpm1F:
597 case TensorExp::Kind::kExpm1C:
598 case TensorExp::Kind::kLog1pF:
599 case TensorExp::Kind::kLog1pC:
600 case TensorExp::Kind::kRelu:
601 case TensorExp::Kind::kSinF:
602 case TensorExp::Kind::kSinC:
603 case TensorExp::Kind::kTanhF:
604 case TensorExp::Kind::kTanhC:
605 case TensorExp::Kind::kNegF:
606 case TensorExp::Kind::kNegC:
607 case TensorExp::Kind::kNegI:
608 case TensorExp::Kind::kTruncF:
609 case TensorExp::Kind::kExtF:
610 case TensorExp::Kind::kCastFS:
611 case TensorExp::Kind::kCastFU:
612 case TensorExp::Kind::kCastSF:
613 case TensorExp::Kind::kCastUF:
614 case TensorExp::Kind::kCastS:
615 case TensorExp::Kind::kCastU:
616 case TensorExp::Kind::kCastIdx:
617 case TensorExp::Kind::kTruncI:
618 case TensorExp::Kind::kCIm:
619 case TensorExp::Kind::kCRe:
620 case TensorExp::Kind::kBitCast:
621 case TensorExp::Kind::kUnary:
622 return isSingleCondition(t, expr.children.e0);
623 case TensorExp::Kind::kBinaryBranch:
624 case TensorExp::Kind::kSelect:
625 return false;
626 // Binary operations.
627 case TensorExp::Kind::kDivF: // note: x / c only
628 case TensorExp::Kind::kDivC:
629 case TensorExp::Kind::kDivS:
630 case TensorExp::Kind::kDivU:
631 assert(!maybeZero(expr.children.e1));
632 return isSingleCondition(t, expr.children.e0);
633 case TensorExp::Kind::kShrS: // note: x >> inv only
634 case TensorExp::Kind::kShrU:
635 case TensorExp::Kind::kShlI:
636 assert(isInvariant(expr.children.e1));
637 return isSingleCondition(t, expr.children.e0);
638 case TensorExp::Kind::kMulF:
639 case TensorExp::Kind::kMulC:
640 case TensorExp::Kind::kMulI:
641 case TensorExp::Kind::kAndI:
642 case TensorExp::Kind::kReduce:
643 if (isSingleCondition(t, expr.children.e0))
644 return isSingleCondition(t, expr.children.e1) ||
645 isInvariant(expr.children.e1);
646 if (isSingleCondition(t, expr.children.e1))
647 return isInvariant(expr.children.e0);
648 return false;
649 case TensorExp::Kind::kAddF:
650 case TensorExp::Kind::kAddC:
651 case TensorExp::Kind::kAddI:
652 return isSingleCondition(t, expr.children.e0) &&
653 isSingleCondition(t, expr.children.e1);
654 case TensorExp::Kind::kSubF:
655 case TensorExp::Kind::kSubC:
656 case TensorExp::Kind::kSubI:
657 case TensorExp::Kind::kOrI:
658 case TensorExp::Kind::kXorI:
659 case TensorExp::Kind::kCmpF:
660 case TensorExp::Kind::kCmpI:
661 case TensorExp::Kind::kBinary:
662 return false;
663 case TensorExp::Kind::kDenseOp:
    // Since the Merger guarantees that all operands of a kDenseOp are dense,
    // the operation must be single-condition.
666 return true;
667 }
668 llvm_unreachable("unexpected kind");
669 }
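
// For example, a(i) * c with an invariant c is a single condition for
// tensor a (the expression only needs to be evaluated where a is present),
// whereas a(i) + b(i) is not, since the addition must also be evaluated
// where only b is present.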
670
bool Merger::hasAnySparse(const BitVector &bits) const {
672 for (TensorLoopId b : bits.set_bits()) {
673 const auto lt = getLvlType(b);
674 if (lt.hasSparseSemantic())
675 return true;
676 }
677 return hasSparseIdxReduction(bits);
678 }
679
bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
681 for (TensorLoopId b : bits.set_bits())
682 if (isSparseLvlWithNonTrivialIdxExp(b))
683 return true;
684 return false;
685 }
686
687 #ifndef NDEBUG
688
689 //===----------------------------------------------------------------------===//
690 // Print methods (for debugging).
691 //===----------------------------------------------------------------------===//
692
static const char *kindToOpSymbol(TensorExp::Kind kind) {
694 switch (kind) {
695 // Leaf.
696 case TensorExp::Kind::kTensor:
697 return "tensor";
698 case TensorExp::Kind::kInvariant:
699 return "invariant";
700 case TensorExp::Kind::kLoopVar:
701 return "index";
702 case TensorExp::Kind::kSynZero:
703 return "0";
704 // Unary operations.
705 case TensorExp::Kind::kAbsF:
706 case TensorExp::Kind::kAbsC:
707 case TensorExp::Kind::kAbsI:
708 return "abs";
709 case TensorExp::Kind::kCeilF:
710 return "ceil";
711 case TensorExp::Kind::kFloorF:
712 return "floor";
713 case TensorExp::Kind::kSqrtF:
714 case TensorExp::Kind::kSqrtC:
715 return "sqrt";
716 case TensorExp::Kind::kExpm1F:
717 case TensorExp::Kind::kExpm1C:
718 return "expm1";
719 case TensorExp::Kind::kLog1pF:
720 case TensorExp::Kind::kLog1pC:
721 return "log1p";
722 case TensorExp::Kind::kRelu:
723 return "relu";
724 case TensorExp::Kind::kSinF:
725 case TensorExp::Kind::kSinC:
726 return "sin";
727 case TensorExp::Kind::kTanhF:
728 case TensorExp::Kind::kTanhC:
729 return "tanh";
730 case TensorExp::Kind::kNegF:
731 case TensorExp::Kind::kNegC:
732 case TensorExp::Kind::kNegI:
733 return "-";
734 case TensorExp::Kind::kTruncF:
735 case TensorExp::Kind::kExtF:
736 case TensorExp::Kind::kCastFS:
737 case TensorExp::Kind::kCastFU:
738 case TensorExp::Kind::kCastSF:
739 case TensorExp::Kind::kCastUF:
740 case TensorExp::Kind::kCastS:
741 case TensorExp::Kind::kCastU:
742 case TensorExp::Kind::kCastIdx:
743 case TensorExp::Kind::kTruncI:
744 case TensorExp::Kind::kCIm:
745 return "complex.im";
746 case TensorExp::Kind::kCRe:
747 return "complex.re";
748 case TensorExp::Kind::kBitCast:
749 return "cast";
750 case TensorExp::Kind::kBinaryBranch:
751 return "binary_branch";
752 case TensorExp::Kind::kUnary:
753 return "unary";
754 case TensorExp::Kind::kSelect:
755 return "select";
756 // Binary operations.
757 case TensorExp::Kind::kMulF:
758 case TensorExp::Kind::kMulC:
759 case TensorExp::Kind::kMulI:
760 return "*";
761 case TensorExp::Kind::kDivF:
762 case TensorExp::Kind::kDivC:
763 case TensorExp::Kind::kDivS:
764 case TensorExp::Kind::kDivU:
765 return "/";
766 case TensorExp::Kind::kAddF:
767 case TensorExp::Kind::kAddC:
768 case TensorExp::Kind::kAddI:
769 return "+";
770 case TensorExp::Kind::kSubF:
771 case TensorExp::Kind::kSubC:
772 case TensorExp::Kind::kSubI:
773 return "-";
774 case TensorExp::Kind::kAndI:
775 return "&";
776 case TensorExp::Kind::kOrI:
777 return "|";
778 case TensorExp::Kind::kXorI:
779 return "^";
780 case TensorExp::Kind::kShrS:
781 return "a>>";
782 case TensorExp::Kind::kShrU:
783 return ">>";
784 case TensorExp::Kind::kShlI:
785 return "<<";
786 case TensorExp::Kind::kCmpF:
787 case TensorExp::Kind::kCmpI:
788 return "cmp";
789 case TensorExp::Kind::kBinary:
790 return "binary";
791 case TensorExp::Kind::kReduce:
792 return "reduce";
793 case TensorExp::Kind::kDenseOp:
794 return "dense";
795 }
796 llvm_unreachable("unexpected kind for symbol");
797 }
798
void Merger::dumpExp(ExprId e) const {
800 const auto &expr = exp(e);
801 switch (expr.kind) {
802 // Leaf.
803 case TensorExp::Kind::kTensor:
804 if (expr.tensor == syntheticTensor)
805 llvm::dbgs() << "synthetic_";
806 else if (expr.tensor == outTensor)
807 llvm::dbgs() << "output_";
808 llvm::dbgs() << "tensor_" << expr.tensor;
809 break;
810 case TensorExp::Kind::kInvariant:
811 llvm::dbgs() << "invariant";
812 break;
813 case TensorExp::Kind::kSynZero:
814 llvm::dbgs() << "0";
815 break;
816 case TensorExp::Kind::kLoopVar:
817 llvm::dbgs() << "loopvar_" << expr.loop;
818 break;
819 // Unary operations.
820 case TensorExp::Kind::kAbsF:
821 case TensorExp::Kind::kAbsC:
822 case TensorExp::Kind::kAbsI:
823 case TensorExp::Kind::kCeilF:
824 case TensorExp::Kind::kFloorF:
825 case TensorExp::Kind::kSqrtF:
826 case TensorExp::Kind::kSqrtC:
827 case TensorExp::Kind::kExpm1F:
828 case TensorExp::Kind::kExpm1C:
829 case TensorExp::Kind::kLog1pF:
830 case TensorExp::Kind::kLog1pC:
831 case TensorExp::Kind::kRelu:
832 case TensorExp::Kind::kSinF:
833 case TensorExp::Kind::kSinC:
834 case TensorExp::Kind::kTanhF:
835 case TensorExp::Kind::kTanhC:
836 case TensorExp::Kind::kNegF:
837 case TensorExp::Kind::kNegC:
838 case TensorExp::Kind::kNegI:
839 case TensorExp::Kind::kTruncF:
840 case TensorExp::Kind::kExtF:
841 case TensorExp::Kind::kCastFS:
842 case TensorExp::Kind::kCastFU:
843 case TensorExp::Kind::kCastSF:
844 case TensorExp::Kind::kCastUF:
845 case TensorExp::Kind::kCastS:
846 case TensorExp::Kind::kCastU:
847 case TensorExp::Kind::kCastIdx:
848 case TensorExp::Kind::kTruncI:
849 case TensorExp::Kind::kCIm:
850 case TensorExp::Kind::kCRe:
851 case TensorExp::Kind::kBitCast:
852 case TensorExp::Kind::kBinaryBranch:
853 case TensorExp::Kind::kUnary:
854 case TensorExp::Kind::kSelect:
855 llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
856 dumpExp(expr.children.e0);
857 break;
858 // Binary operations.
859 case TensorExp::Kind::kMulF:
860 case TensorExp::Kind::kMulC:
861 case TensorExp::Kind::kMulI:
862 case TensorExp::Kind::kDivF:
863 case TensorExp::Kind::kDivC:
864 case TensorExp::Kind::kDivS:
865 case TensorExp::Kind::kDivU:
866 case TensorExp::Kind::kAddF:
867 case TensorExp::Kind::kAddC:
868 case TensorExp::Kind::kAddI:
869 case TensorExp::Kind::kSubF:
870 case TensorExp::Kind::kSubC:
871 case TensorExp::Kind::kSubI:
872 case TensorExp::Kind::kAndI:
873 case TensorExp::Kind::kOrI:
874 case TensorExp::Kind::kXorI:
875 case TensorExp::Kind::kShrS:
876 case TensorExp::Kind::kShrU:
877 case TensorExp::Kind::kShlI:
878 case TensorExp::Kind::kCmpF:
879 case TensorExp::Kind::kCmpI:
880 case TensorExp::Kind::kBinary:
881 case TensorExp::Kind::kReduce:
882 case TensorExp::Kind::kDenseOp:
883 llvm::dbgs() << "(";
884 dumpExp(expr.children.e0);
885 llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
886 if (expr.attr)
887 llvm::dbgs() << "{" << expr.attr << "}";
888 if (expr.children.e1 != detail::kInvalidId) {
889 llvm::dbgs() << " ";
890 dumpExp(expr.children.e1);
891 llvm::dbgs() << ")";
892 } else {
893 assert(expr.kind == TensorExp::Kind::kDenseOp);
894 }
895 break;
896 }
897 }
898
void Merger::dumpLat(LatPointId p) const {
900 const auto &point = lat(p);
901 llvm::dbgs() << "lat(";
902 dumpBits(point.bits);
903 llvm::dbgs() << " :";
904 dumpBits(point.simple);
905 llvm::dbgs() << " : ";
906 dumpExp(point.exp);
907 llvm::dbgs() << " )\n";
908 }
909
void Merger::dumpSet(LatSetId s) const {
911 const auto &ss = set(s);
912 llvm::dbgs() << "{ #" << ss.size() << "\n";
913 for (const LatPointId p : ss) {
914 llvm::dbgs() << " ";
915 dumpLat(p);
916 }
917 llvm::dbgs() << "}\n";
918 }
919
void Merger::dumpBits(const BitVector &bits) const {
921 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
922 if (bits[b]) {
923 const TensorId t = tensor(b);
924 const LoopId i = loop(b);
925 const auto lt = lvlTypes[t][i];
926 if (isLvlWithNonTrivialIdxExp(b))
927 llvm::dbgs() << " DEP_" << t << "_" << i;
928 else
929 llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
930 }
931 }
932 }
933
934 #endif // NDEBUG
935
936 //===----------------------------------------------------------------------===//
937 // Builder methods.
938 //===----------------------------------------------------------------------===//
939
LatSetId Merger::buildLattices(ExprId e, LoopId i) {
941 // NOTE: The `expr` reference will be invalidated by recursive calls
942 // (and any other method that may add new expressions); therefore, the
943 // code below must make sure to copy fields of `expr` into local variables
944 // before making any recursive calls.
945 const auto &expr = exp(e);
946 const TensorExp::Kind kind = expr.kind;
947 switch (kind) {
948 // Leaf.
949 case TensorExp::Kind::kTensor:
950 case TensorExp::Kind::kInvariant:
951 case TensorExp::Kind::kSynZero:
952 case TensorExp::Kind::kLoopVar: {
953 // Either the loop-var is really used in the tensor expression, or it is
954 // set to the undefined loop-var in that level. An invariant expression,
955 // a proper index value, and a truly dynamic sparse output tensor are set
956 // to a synthetic tensor with undefined indices only to ensure the
957 // iteration space is not skipped as a result of their contents.
958 const LatSetId s = addSet();
959 TensorId t = syntheticTensor;
960 if (kind == TensorExp::Kind::kTensor) {
961 t = expr.tensor;
962 if (hasSparseOut && t == outTensor)
963 t = syntheticTensor;
964 }
965 latSets[s].push_back(addLat(t, i, e));
966 return s;
967 }
968 // Unary operations.
969 case TensorExp::Kind::kAbsF:
970 case TensorExp::Kind::kAbsC:
971 case TensorExp::Kind::kAbsI:
972 case TensorExp::Kind::kCeilF:
973 case TensorExp::Kind::kFloorF:
974 case TensorExp::Kind::kSqrtF:
975 case TensorExp::Kind::kSqrtC:
976 case TensorExp::Kind::kExpm1F:
977 case TensorExp::Kind::kExpm1C:
978 case TensorExp::Kind::kLog1pF:
979 case TensorExp::Kind::kLog1pC:
980 case TensorExp::Kind::kRelu:
981 case TensorExp::Kind::kSinF:
982 case TensorExp::Kind::kSinC:
983 case TensorExp::Kind::kTanhF:
984 case TensorExp::Kind::kTanhC:
985 case TensorExp::Kind::kNegF:
986 case TensorExp::Kind::kNegC:
987 case TensorExp::Kind::kNegI:
988 case TensorExp::Kind::kTruncF:
989 case TensorExp::Kind::kExtF:
990 case TensorExp::Kind::kCastFS:
991 case TensorExp::Kind::kCastFU:
992 case TensorExp::Kind::kCastSF:
993 case TensorExp::Kind::kCastUF:
994 case TensorExp::Kind::kCastS:
995 case TensorExp::Kind::kCastU:
996 case TensorExp::Kind::kCastIdx:
997 case TensorExp::Kind::kTruncI:
998 case TensorExp::Kind::kCIm:
999 case TensorExp::Kind::kCRe:
1000 case TensorExp::Kind::kBitCast:
1001 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
1002 // lattice set of the operand through the operator into a new set.
1003 //
1004 // -y|!y | y |
1005 // --+---+---+
1006 // | 0 |-y |
1007 {
1008 const ExprId e0 = expr.children.e0;
1009 const Value v = expr.val;
1010 Attribute a = expr.attr;
1011 return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
1012 }
1013 case TensorExp::Kind::kBinaryBranch:
1014 case TensorExp::Kind::kSelect:
1015 // The left or right half of a binary operation which has already
1016 // been split into separate operations for each region.
1017 {
1018 const ExprId e0 = expr.children.e0;
1019 Operation *const op = expr.op;
1020 return mapSet(kind, buildLattices(e0, i), Value(), op);
1021 }
1022 case TensorExp::Kind::kUnary:
1023 // A custom unary operation.
1024 //
1025 // op y| !y | y |
1026 // ----+----------+------------+
1027 // | absent() | present(y) |
1028 {
1029 const ExprId e0 = expr.children.e0;
1030 UnaryOp unop = cast<UnaryOp>(expr.op);
1031 const LatSetId child0 = buildLattices(e0, i);
1032 Region &absentRegion = unop.getAbsentRegion();
1033 if (absentRegion.empty()) {
1034 // Simple mapping over existing values.
1035 return mapSet(kind, child0, Value(), unop);
1036 }
1037 // Use a disjunction with `unop` on the left and the absent value as an
1038 // invariant on the right.
1039 Block &absentBlock = absentRegion.front();
1040 YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1041 const Value absentVal = absentYield.getSingleResult();
1042 const ExprId rhs = addInvariantExp(absentVal);
1043 return disjSet(e, child0, buildLattices(rhs, i), unop);
1044 }
1045 // Binary operations.
1046 case TensorExp::Kind::kMulF:
1047 case TensorExp::Kind::kMulC:
1048 case TensorExp::Kind::kMulI:
1049 case TensorExp::Kind::kAndI:
1050 // A multiplicative operation only needs to be performed
1051 // for the conjunction of sparse iteration spaces.
1052 //
1053 // x*y|!y | y |
1054 // ---+---+---+
1055 // !x | 0 | 0 |
1056 // x | 0 |x*y|
1057 //
1058 // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1059 {
1060 const ExprId e0 = expr.children.e0;
1061 const ExprId e1 = expr.children.e1;
1062 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1063 }
1064 case TensorExp::Kind::kDivF:
1065 case TensorExp::Kind::kDivC:
1066 case TensorExp::Kind::kDivS:
1067 case TensorExp::Kind::kDivU:
1068 // A division is tricky, since 0/0, 0/c, c/0 all have
1069 // specific outcomes for floating-point and integers.
1070 // Thus, we need to traverse the full iteration space.
1071 //
1072 // x/y|!y | y |
1073 // ---+---+---+
1074 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1075 // x |x/0|x/y| INT: x/0=exception for any x
1076 //
1077 // TODO: for now we "fixed" this by only accepting x/c cases
1078 // during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
1080 // construction is concerned).
1081 {
1082 const ExprId e0 = expr.children.e0;
1083 const ExprId e1 = expr.children.e1;
1084 assert(!maybeZero(e1));
1085 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1086 }
1087 case TensorExp::Kind::kAddF:
1088 case TensorExp::Kind::kAddC:
1089 case TensorExp::Kind::kAddI:
1090 case TensorExp::Kind::kSubF:
1091 case TensorExp::Kind::kSubC:
1092 case TensorExp::Kind::kSubI:
1093 case TensorExp::Kind::kOrI:
1094 case TensorExp::Kind::kXorI:
1095 // An additive operation needs to be performed
1096 // for the disjunction of sparse iteration spaces.
1097 //
1098 // x+y|!y | y | x-y|!y | y |
1099 // ---+---+---+ ---+---+---+
1100 // !x | 0 | y | !x | 0 |-y |
1101 // x | x |x+y| x | x |x-y|
1102 {
1103 const ExprId e0 = expr.children.e0;
1104 const ExprId e1 = expr.children.e1;
1105 return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1106 }
1107 case TensorExp::Kind::kCmpF:
1108 case TensorExp::Kind::kCmpI:
1109 // A comparison operation needs to be performed
1110 // for the disjunction of sparse iteration spaces.
1111 //
1112 // x < y | !y | y |
1113 // -------+-------+-------+
1114 // !x | 0 | 0 < y |
1115 // x | x < 0 | x < y |
1116 {
1117 const ExprId e0 = expr.children.e0;
1118 const ExprId e1 = expr.children.e1;
1119 return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1120 }
1121 case TensorExp::Kind::kShrS:
1122 case TensorExp::Kind::kShrU:
1123 case TensorExp::Kind::kShlI:
1124 // A shift operation by an invariant amount (viz. tensor expressions
1125 // can only occur at the left-hand-side of the operator) can be handled
1126 // with the conjunction rule.
1127 {
1128 const ExprId e0 = expr.children.e0;
1129 const ExprId e1 = expr.children.e1;
1130 assert(isInvariant(e1));
1131 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1132 }
1133 case TensorExp::Kind::kBinary:
1134 // A custom binary operation.
1135 //
1136 // x op y| !y | y |
1137 // ------+---------+--------------+
1138 // !x | empty | right(y) |
1139 // x | left(x) | overlap(x,y) |
1140 {
1141 const ExprId e0 = expr.children.e0;
1142 const ExprId e1 = expr.children.e1;
1143 BinaryOp binop = cast<BinaryOp>(expr.op);
1144 const LatSetId child0 = buildLattices(e0, i);
1145 const LatSetId child1 = buildLattices(e1, i);
1146 Region &leftRegion = binop.getLeftRegion();
1147 Region &rightRegion = binop.getRightRegion();
1148 // Left Region.
1149 Operation *leftYield = nullptr;
1150 if (!leftRegion.empty()) {
1151 Block &leftBlock = leftRegion.front();
1152 leftYield = leftBlock.getTerminator();
1153 }
1154 // Right Region.
1155 Operation *rightYield = nullptr;
1156 if (!rightRegion.empty()) {
1157 Block &rightBlock = rightRegion.front();
1158 rightYield = rightBlock.getTerminator();
1159 }
1160 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1161 bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1162 return combiSet(e, child0, child1, binop, includeLeft,
1163 TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1164 TensorExp::Kind::kBinaryBranch, rightYield);
1165 }
1166 case TensorExp::Kind::kReduce:
1167 // A custom reduce operation.
1168 {
1169 const ExprId e0 = expr.children.e0;
1170 const ExprId e1 = expr.children.e1;
1171 Operation *const op = expr.op;
1172 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1173 }
1174 case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or a disjunctive
    // set here: since all the operands of kDenseOp must be dense, the
    // disjunctive set will be optimized into a conjunctive set eventually.
1178 if (expr.children.e1 == detail::kInvalidId) {
1179 const ExprId e0 = expr.children.e0;
1180 Operation *const op = expr.op;
1181 return mapSet(kind, buildLattices(e0, i), Value(), op);
1182 }
1183
1184 const ExprId e0 = expr.children.e0;
1185 const ExprId e1 = expr.children.e1;
1186 Operation *const op = expr.op;
1187 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1188 }
1189 }
1190 llvm_unreachable("unexpected expression kind");
1191 }
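
// For example, building the lattices of a(i) + b(i) for loop i yields a set
// with three points: {a, b} computing a(i)+b(i), {a} computing a(i), and {b}
// computing b(i), matching the disjunction table above.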
1192
std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1194 // Build the linalg semantics backward from yield.
1195 Operation *yield = op.getRegion().front().getTerminator();
1196 assert(isa<linalg::YieldOp>(yield));
1197 return buildTensorExp(op, yield->getOperand(0)).first;
1198 }
1199
1200 /// Only returns true if we are certain this is a zero.
static bool isCertainZero(Value val) {
1202 if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
1203 ArrayAttr arrayAttr = c.getValue();
1204 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1205 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1206 }
1207 if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
1208 return c.value() == 0;
1209 if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
1210 return c.value().isZero();
1211 return false;
1212 }
1213
1214 /// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
1216 const auto &expr = exp(e);
1217 if (expr.kind == TensorExp::Kind::kInvariant) {
1218 // Note that this is different from isCertainZero() in a subtle
1219 // way by always returning true for non-constants.
1220 if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1221 ArrayAttr arrayAttr = c.getValue();
1222 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1223 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1224 }
1225 if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1226 return c.value() == 0;
1227 if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1228 return c.value().isZero();
1229 }
1230 return true;
1231 }
1232
Type Merger::inferType(ExprId e, Value src) const {
1234 // Obtain the destination type from the cast node.
1235 Type dtp = exp(e).val.getType();
1236 // Inspect source type. For vector types, apply the same
1237 // vectorization to the destination type.
1238 if (auto vtp = dyn_cast<VectorType>(src.getType()))
1239 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1240 return dtp;
1241 }
1242
1243 /// Ensures that the sparsifier can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1245 // Arguments are always admissible.
1246 if (isa<BlockArgument>(v))
1247 return true;
1248 // Accept index anywhere.
1249 Operation *def = v.getDefiningOp();
1250 if (isa<linalg::IndexOp>(def))
1251 return true;
1252 // Operation defined outside branch.
1253 if (def->getBlock() != block)
1254 return def->getBlock() != op->getBlock(); // invariant?
1255 // Operation defined within branch. Anything is accepted,
1256 // as long as all subexpressions are admissible.
1257 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1258 if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1259 return false;
1260 return true;
1261 }
1262
1263 /// Ensures that the sparsifier can generate code for branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
1265 if (region.empty())
1266 return true;
1267 // Build the semi-ring branch semantics backward from yield.
1268 Operation *yield = region.front().getTerminator();
1269 assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1271 }
1272
1273 // Recognizes a direct GT comparison.
static bool isGreater(TensorExp::Kind kind, Attribute attr) {
1275 if (kind == TensorExp::Kind::kCmpI) {
1276 auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
1277 return pred == arith::CmpIPredicate::ugt ||
1278 pred == arith::CmpIPredicate::sgt;
1279 }
1280 if (kind == TensorExp::Kind::kCmpF) {
1281 auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
1282 return pred == arith::CmpFPredicate::UGT ||
1283 pred == arith::CmpFPredicate::OGT;
1284 }
1285 return false;
1286 }
1287
1288 std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1290 // Recursion leaves.
1291 if (auto arg = dyn_cast<BlockArgument>(v)) {
1292 const TensorId tid = makeTensorId(arg.getArgNumber());
1293 // Any argument of the generic op that is not marked as a scalar
1294 // argument is considered a tensor, indexed by the implicit loop
1295 // bounds. This includes rank-0 tensor arguments.
1296 if (arg.getOwner()->getParentOp() == op) {
1297 OpOperand &t = op->getOpOperand(tid);
1298 bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1299 if (!op.isScalar(&t))
1300 return {addTensorExp(tid), hasSpDep};
1301 v = t.get(); // get scalar value
1302 }
1303 // Any other argument (marked as scalar argument for the generic op
1304 // or belonging to an enveloping op) is considered invariant.
1305 return {addInvariantExp(v), /*hasSpDep=*/false};
1306 }
1307
1308 // Something defined outside is invariant.
1309 Operation *def = v.getDefiningOp();
1310 if (def->getBlock() != &op.getRegion().front())
1311 return {addInvariantExp(v), /*hasSpDep=*/false};
1312 // Construct index operations.
1313 if (def->getNumOperands() == 0) {
1314 if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1315 return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1316 }
1317
1318 // Construct unary operations if subexpression can be built.
1319 if (def->getNumOperands() == 1) {
1320 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321 if (x.has_value()) {
1322 const ExprId e = *x;
1323 if (isa<math::AbsFOp>(def))
1324 return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1325 if (isa<complex::AbsOp>(def))
1326 return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1327 if (isa<math::AbsIOp>(def))
1328 return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1329 if (isa<math::CeilOp>(def))
1330 return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1331 if (isa<math::FloorOp>(def))
1332 return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1333 if (isa<math::SqrtOp>(def))
1334 return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1335 if (isa<complex::SqrtOp>(def))
1336 return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1337 if (isa<math::ExpM1Op>(def))
1338 return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1339 if (isa<complex::Expm1Op>(def))
1340 return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1341 if (isa<math::Log1pOp>(def))
1342 return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1343 if (isa<complex::Log1pOp>(def))
1344 return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1345 if (isa<math::SinOp>(def))
1346 return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1347 if (isa<complex::SinOp>(def))
1348 return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1349 if (isa<math::TanhOp>(def))
1350 return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1351 if (isa<complex::TanhOp>(def))
1352 return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1353 if (isa<arith::NegFOp>(def))
1354 return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1355 if (isa<complex::NegOp>(def))
1356 return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1357 if (isa<arith::TruncFOp>(def))
1358 return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1359 if (isa<arith::ExtFOp>(def))
1360 return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1361 if (isa<arith::FPToSIOp>(def))
1362 return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1363 if (isa<arith::FPToUIOp>(def))
1364 return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1365 if (isa<arith::SIToFPOp>(def))
1366 return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1367 if (isa<arith::UIToFPOp>(def))
1368 return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1369 if (isa<arith::ExtSIOp>(def))
1370 return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1371 if (isa<arith::ExtUIOp>(def))
1372 return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1373 if (isa<arith::IndexCastOp>(def))
1374 return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1375 if (isa<arith::TruncIOp>(def))
1376 return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1377 if (isa<complex::ImOp>(def))
1378 return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1379 if (isa<complex::ReOp>(def))
1380 return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1381 if (isa<arith::BitcastOp>(def))
1382 return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1383 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1384 if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1385 isAdmissibleBranch(unop, unop.getAbsentRegion()))
1386 return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1387 }
1388 if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1389 if (isAdmissibleBranch(selop, selop.getRegion()))
1390 return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1391 }
1392 }
1393 }
1394
1395 // Construct binary operations if subexpressions can be built.
1396 // See buildLattices() for an explanation of rejecting certain
1397 // division and shift operations.
1398 if (def->getNumOperands() == 2) {
1399 const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
1400 const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1401 // For a conjunctive operation, it yields a "sparse" result if any operand
1402 // is sparse. For a disjunctive operation, it yields a "sparse" result if
1403 // all operands are sparse.
1404 bool conjSpVals = xSpVals || ySpVals;
1405 bool disjSpVals = xSpVals && ySpVals;
1406 if (x.has_value() && y.has_value()) {
1407 const ExprId e0 = *x;
1408 const ExprId e1 = *y;
1409 if (isa<arith::MulFOp>(def))
1410 return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
1411 if (isa<complex::MulOp>(def))
1412 return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
1413 if (isa<arith::MulIOp>(def))
1414 return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
1415 if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1416 return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
1417 if (isa<complex::DivOp>(def) && !maybeZero(e1))
1418 return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
1419 if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1420 return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
1421 if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1422 return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
1423 if (isa<arith::AddFOp>(def))
1424 return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
1425 if (isa<complex::AddOp>(def))
1426 return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
1427 if (isa<arith::AddIOp>(def))
1428 return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
1429 if (isa<arith::SubFOp>(def))
1430 return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
1431 if (isa<complex::SubOp>(def))
1432 return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
1433 if (isa<arith::SubIOp>(def))
1434 return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
1435 if (isa<arith::AndIOp>(def))
1436 return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
1437 if (isa<arith::OrIOp>(def))
1438 return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
1439 if (isa<arith::XOrIOp>(def))
1440 return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
1441 if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1442 return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
1443 if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1444 return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
1445 if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1446 return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
      if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
        if (ci.getPredicate() == arith::CmpIPredicate::eq ||
            ci.getPredicate() == arith::CmpIPredicate::sle ||
            ci.getPredicate() == arith::CmpIPredicate::sge ||
            ci.getPredicate() == arith::CmpIPredicate::ule ||
            ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify comparisons that include equality, because
          // 0 <= 0 yields true and thus densifies the result.
          return {std::nullopt, false};
        }
1457
1458 auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1459 ci.getPredicateAttr());
1460 return {e, conjSpVals};
1461 }
      if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify these comparisons, because, e.g., 0 <= 0
          // yields true and thus densifies the result.
          return {std::nullopt, false};
        }
1476 auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1477 cf.getPredicateAttr());
1478 return {e, conjSpVals};
1479 }
1480 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1481 if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1482 (binop.getLeftIdentity() ||
1483 isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1484 (binop.getRightIdentity() ||
1485 isAdmissibleBranch(binop, binop.getRightRegion())))
1486 return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
1487 }
1488 }
1489 }
1490
1491 // Construct ternary operations if subexpressions can be built.
1492 if (def->getNumOperands() == 3) {
1493 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497 if (x.has_value() && y.has_value() && z.has_value()) {
1498 const ExprId e0 = *x;
1499 const ExprId e1 = *y;
1500 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1501 if (isAdmissibleBranch(redop, redop.getRegion()))
1502 return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1503 }
1504 if (auto selop = dyn_cast<arith::SelectOp>(def)) {
1505 // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
1506 // operation inside a very specific ternary select operation.
1507 // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
1508 const auto &cnd = exp(*x);
1509 if (isGreater(cnd.kind, cnd.attr) &&
1510 exp(*y).kind == TensorExp::Kind::kTensor &&
1511 exp(*z).kind == TensorExp::Kind::kInvariant &&
1512 isCertainZero(exp(*z).val)) {
1513 const auto &a = exp(cnd.children.e0);
1514 const auto &b = exp(cnd.children.e1);
1515 if (a.kind == TensorExp::Kind::kTensor &&
1516 a.tensor == exp(*y).tensor &&
1517 b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
1518 return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
1519 nullptr, cnd.attr),
1520 yDepSp};
1521 }
1522 }
1523 }
1524 }
1525 }
1526
1527 // If we reach here, we are dealing with an operation that is not currently
1528 // sparsifiable. We can still generate code for it if all its operands only
1529 // have dense dependencies (i.e., all the values are loaded from dense
1530 // tensors).
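// For example, an op such as math.atan2, which has no dedicated TensorExp
// kind, can still be emitted verbatim (as kDenseOp) when all its operands
// are dense.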
1531 if (def->getNumResults() != 1) // only handle single-result operations.
1532 return {std::nullopt, false};
1533 SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1534 // Build all the sub-expressions.
1535 for (Value operand : def->getOperands())
1536 subExp.push_back(buildTensorExp(op, operand));
1537
1538 if (llvm::all_of(subExp,
1539 [](auto e) { return e.first.has_value() && !e.second; })) {
1540 // All the subexpressions can be built and have *no* sparse dependencies.
1541 if (subExp.size() == 2) {
1542 auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1543 *subExp[1].first, def);
1544 return {e, false};
1545 }
1546 if (subExp.size() == 1) {
1547 auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1548 detail::kInvalidId, def);
1549 return {e, false};
1550 }
1551 }
1552
1553 // Cannot build.
1554 return {std::nullopt, false};
1555 }
1556
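/// Clones the given semi-ring region, inlines the clone at the current
/// insertion point with `vals` bound to its block arguments, and returns the
/// value produced by its sparse_tensor.yield terminator.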
insertYieldOp(RewriterBase & rewriter,Location loc,Region & region,ValueRange vals)1557 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1558 ValueRange vals) {
1559 // Make a clone of the given semi-ring region.
1560 Region tmpRegion;
1561 IRMapping mapper;
1562 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1563 Block &clonedBlock = tmpRegion.front();
1564 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1565 // Merge cloned block and return yield value.
1566 Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1567 rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1568 Value val = clonedYield.getSingleResult();
1569 rewriter.eraseOp(clonedYield);
1570 rewriter.eraseOp(placeholder);
1571 return val;
1572 }
1573
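/// Builds the scalar value for the "present" case of a sparse_tensor.unary
/// operation by inlining its present region on `v0`; an empty Value is
/// interpreted as missing data in the output.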
buildUnaryPresent(RewriterBase & rewriter,Location loc,Operation * op,Value v0)1574 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1575 Operation *op, Value v0) {
1576 if (!v0)
1577 // Empty input value must be propagated.
1578 return Value();
1579 UnaryOp unop = cast<UnaryOp>(op);
1580 Region &presentRegion = unop.getPresentRegion();
1581 if (presentRegion.empty())
1582 // Uninitialized Value() will be interpreted as missing data in the
1583 // output.
1584 return Value();
1585 return insertYieldOp(rewriter, loc, presentRegion, {v0});
1586 }
1587
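/// Builds the scalar value for the "overlap" case of a sparse_tensor.binary
/// operation by inlining its overlap region on `v0` and `v1`; an empty Value
/// is interpreted as missing data in the output.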
buildBinaryOverlap(RewriterBase & rewriter,Location loc,Operation * op,Value v0,Value v1)1588 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1589 Operation *op, Value v0, Value v1) {
1590 if (!v0 || !v1)
1591 // Empty input values must be propagated.
1592 return Value();
1593 BinaryOp binop = cast<BinaryOp>(op);
1594 Region &overlapRegion = binop.getOverlapRegion();
1595 if (overlapRegion.empty())
1596 // Uninitialized Value() will be interpreted as missing data in the
1597 // output.
1598 return Value();
1599 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1600 }
1601
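/// Lowers a recognized ReLU expression to max(v0, 0), emitted as a comparison
/// against a zero constant followed by an arith.select, reusing the original
/// comparison predicate stored in `attr`.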
buildRelu(RewriterBase & rewriter,Location loc,Value v0,Attribute attr)1602 static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
1603 Attribute attr) {
1604 Type tp = v0.getType();
1605 auto zero =
1606 rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
1607 Value cmp;
1608 if (isa<FloatType>(tp)) {
1609 auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
1610 cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
1611 } else {
1612 auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
1613 cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
1614 }
1615 return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
1616 }
1617
buildExp(RewriterBase & rewriter,Location loc,ExprId e,Value v0,Value v1) const1618 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1619 Value v1) const {
1620 const auto &expr = exp(e);
1621 switch (expr.kind) {
1622 // Leaf.
1623 case TensorExp::Kind::kTensor:
1624 case TensorExp::Kind::kInvariant:
1625 case TensorExp::Kind::kLoopVar:
1626 case TensorExp::Kind::kSynZero:
1627 llvm_unreachable("unexpected non-op");
1628 // Unary operations.
1629 case TensorExp::Kind::kAbsF:
1630 return rewriter.create<math::AbsFOp>(loc, v0);
1631 case TensorExp::Kind::kAbsC: {
1632 auto type = cast<ComplexType>(v0.getType());
1633 auto eltType = cast<FloatType>(type.getElementType());
1634 return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1635 }
1636 case TensorExp::Kind::kAbsI:
1637 return rewriter.create<math::AbsIOp>(loc, v0);
1638 case TensorExp::Kind::kCeilF:
1639 return rewriter.create<math::CeilOp>(loc, v0);
1640 case TensorExp::Kind::kFloorF:
1641 return rewriter.create<math::FloorOp>(loc, v0);
1642 case TensorExp::Kind::kSqrtF:
1643 return rewriter.create<math::SqrtOp>(loc, v0);
1644 case TensorExp::Kind::kSqrtC:
1645 return rewriter.create<complex::SqrtOp>(loc, v0);
1646 case TensorExp::Kind::kExpm1F:
1647 return rewriter.create<math::ExpM1Op>(loc, v0);
1648 case TensorExp::Kind::kExpm1C:
1649 return rewriter.create<complex::Expm1Op>(loc, v0);
1650 case TensorExp::Kind::kLog1pF:
1651 return rewriter.create<math::Log1pOp>(loc, v0);
1652 case TensorExp::Kind::kLog1pC:
1653 return rewriter.create<complex::Log1pOp>(loc, v0);
1654 case TensorExp::Kind::kRelu:
1655 return buildRelu(rewriter, loc, v0, expr.attr);
1656 case TensorExp::Kind::kSinF:
1657 return rewriter.create<math::SinOp>(loc, v0);
1658 case TensorExp::Kind::kSinC:
1659 return rewriter.create<complex::SinOp>(loc, v0);
1660 case TensorExp::Kind::kTanhF:
1661 return rewriter.create<math::TanhOp>(loc, v0);
1662 case TensorExp::Kind::kTanhC:
1663 return rewriter.create<complex::TanhOp>(loc, v0);
1664 case TensorExp::Kind::kNegF:
1665 return rewriter.create<arith::NegFOp>(loc, v0);
1666 case TensorExp::Kind::kNegC:
1667 return rewriter.create<complex::NegOp>(loc, v0);
1668 case TensorExp::Kind::kNegI: // no integer negation op in arith; lower as 0 - v0
1669 return rewriter.create<arith::SubIOp>(
1670 loc,
1671 rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1672 rewriter.getZeroAttr(v0.getType())),
1673 v0);
1674 case TensorExp::Kind::kTruncF:
1675 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1676 case TensorExp::Kind::kExtF:
1677 return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1678 case TensorExp::Kind::kCastFS:
1679 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1680 case TensorExp::Kind::kCastFU:
1681 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1682 case TensorExp::Kind::kCastSF:
1683 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1684 case TensorExp::Kind::kCastUF:
1685 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1686 case TensorExp::Kind::kCastS:
1687 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1688 case TensorExp::Kind::kCastU:
1689 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1690 case TensorExp::Kind::kCastIdx:
1691 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1692 case TensorExp::Kind::kTruncI:
1693 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1694 case TensorExp::Kind::kCIm: {
1695 auto type = cast<ComplexType>(v0.getType());
1696 auto eltType = cast<FloatType>(type.getElementType());
1697 return rewriter.create<complex::ImOp>(loc, eltType, v0);
1698 }
1699 case TensorExp::Kind::kCRe: {
1700 auto type = cast<ComplexType>(v0.getType());
1701 auto eltType = cast<FloatType>(type.getElementType());
1702 return rewriter.create<complex::ReOp>(loc, eltType, v0);
1703 }
1704 case TensorExp::Kind::kBitCast:
1705 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1706 // Binary operations.
1707 case TensorExp::Kind::kMulF:
1708 return rewriter.create<arith::MulFOp>(loc, v0, v1);
1709 case TensorExp::Kind::kMulC:
1710 return rewriter.create<complex::MulOp>(loc, v0, v1);
1711 case TensorExp::Kind::kMulI:
1712 return rewriter.create<arith::MulIOp>(loc, v0, v1);
1713 case TensorExp::Kind::kDivF:
1714 return rewriter.create<arith::DivFOp>(loc, v0, v1);
1715 case TensorExp::Kind::kDivC:
1716 return rewriter.create<complex::DivOp>(loc, v0, v1);
1717 case TensorExp::Kind::kDivS:
1718 return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1719 case TensorExp::Kind::kDivU:
1720 return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1721 case TensorExp::Kind::kAddF:
1722 return rewriter.create<arith::AddFOp>(loc, v0, v1);
1723 case TensorExp::Kind::kAddC:
1724 return rewriter.create<complex::AddOp>(loc, v0, v1);
1725 case TensorExp::Kind::kAddI:
1726 return rewriter.create<arith::AddIOp>(loc, v0, v1);
1727 case TensorExp::Kind::kSubF:
1728 return rewriter.create<arith::SubFOp>(loc, v0, v1);
1729 case TensorExp::Kind::kSubC:
1730 return rewriter.create<complex::SubOp>(loc, v0, v1);
1731 case TensorExp::Kind::kSubI:
1732 return rewriter.create<arith::SubIOp>(loc, v0, v1);
1733 case TensorExp::Kind::kAndI:
1734 return rewriter.create<arith::AndIOp>(loc, v0, v1);
1735 case TensorExp::Kind::kOrI:
1736 return rewriter.create<arith::OrIOp>(loc, v0, v1);
1737 case TensorExp::Kind::kXorI:
1738 return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1739 case TensorExp::Kind::kShrS:
1740 return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1741 case TensorExp::Kind::kShrU:
1742 return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1743 case TensorExp::Kind::kShlI:
1744 return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1745 case TensorExp::Kind::kCmpI: {
1746 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1747 return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1748 }
1749 case TensorExp::Kind::kCmpF: {
1750 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1751 return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1752 }
1753 case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1754 return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1755 {v0});
1756 case TensorExp::Kind::kUnary:
1757 return buildUnaryPresent(rewriter, loc, expr.op, v0);
1758 case TensorExp::Kind::kSelect:
1759 return insertYieldOp(rewriter, loc,
1760 cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
1761 {v0});
1762 case TensorExp::Kind::kBinary:
1763 return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1764 case TensorExp::Kind::kReduce: {
1765 ReduceOp redOp = cast<ReduceOp>(expr.op);
1766 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1767 }
1768 case TensorExp::Kind::kDenseOp: {
1769 Operation *actualOp = expr.op;
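// Clone the original dense op with its operands remapped to the freshly
// loaded scalar values v0 (and v1 for the binary case).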
1770 IRMapping mapping;
1771 mapping.map(actualOp->getOperand(0), v0);
1772 if (actualOp->getNumOperands() == 2)
1773 mapping.map(actualOp->getOperand(1), v1);
1774 return rewriter.clone(*actualOp, mapping)->getResult(0);
1775 }
1776 }
1777 llvm_unreachable("unexpected expression kind in build");
1778 }
1779
1780 } // namespace sparse_tensor
1781 } // namespace mlir
1782