//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    attr = a;
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kDenseOp:
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
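// Illustrative example (editorial sketch): in the lattice methods below, each
// lattice point carries a bit vector with one bit per (tensor, loop) pair,
// indexed by makeTensorLoopId(t, i). For a kernel such as x(i) = A(i) * B(i)
// with tensor ids A=0 and B=1 and loop id i=0, addLat(0, 0, e) and
// addLat(1, 0, e) each produce a point with a single bit set, and conjLat()
// unions those bits so the combined point requires both A and B to contribute
// an entry at loop i.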
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
  return eNew;
}

ExprId Merger::addSynZeroExp() {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute attr = exp(e).attr;
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
  return pNew;
}

LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
  return sNew;
}
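// Illustrative example: for x(i) = A(i) * B(i), the conjunction built by
// conjSet() yields the single lattice point {A(i) /\ B(i)}, since the product
// vanishes whenever either operand is absent. For x(i) = A(i) + B(i), the
// disjunction built by disjSet() below yields {A(i) /\ B(i), A(i), B(i)}:
// the overlap region plus the two regions where only one operand is present.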
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = conjSet(e, s0, s1, op);
  TensorExp::Kind kind = exp(e).kind;

  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}

LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build disjunctive set for it.
    return sNew;
  }

  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}

LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {
  const LatSetId sNew = conjSet(e, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op) {
  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
         TensorExp::Kind::kDenseOp == kind);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
  }
  return sNew;
}

LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute a = exp(e).attr;
  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
  // Must be a binary operation.
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const ExprId zeroExp = addSynZeroExp();
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
  }
  return sNew;
}
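// Illustrative example: optimizeSet() and simplifyCond() below prune redundant
// lattice points. For x(i) = A(i) + B(i) with A sparse and B dense, the
// disjunction {A /\ B, A, B} loses the point "A", which differs from "A /\ B"
// only in the dense tensor B, leaving {A /\ B, B}: one co-iteration of A
// within the dense iteration space, plus a follow-up region for entries
// where A is absent.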
LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}

BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not of undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits in reverse to
  // always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!lt.hasSparseSemantic()) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}
bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  case TensorExp::Kind::kDenseOp: {
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
    return lhsNeg;
  }
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
    return false;
  case TensorExp::Kind::kDenseOp:
    // Since Merger guarantees all the operands of the kDenseOp to be dense,
    // the operation must be single-condition.
    return true;
  }
  llvm_unreachable("unexpected kind");
}
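// Illustrative example: isSingleCondition() above asks whether an expression
// is guaranteed to be zero wherever tensor t is zero. A(i) * b with b an
// invariant scalar is a single condition for A, whereas A(i) + B(i) is not,
// since the sum may be nonzero where A has no entry.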
bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto lt = getLvlType(b);
    if (lt.hasSparseSemantic())
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
      llvm::dbgs() << ")";
    } else {
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto lt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
    }
  }
}
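// Illustrative sketch (exact output depends on the level types involved):
// assuming tensor 0 is compressed at the level driven by loop 0, a lattice
// point whose expression is simply that tensor would dump roughly as
//   lat( i_0_0_compressed : i_0_0_compressed : tensor_0 )
// where the first bit list is the full condition and the second the
// simplified condition computed by simplifyCond().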
#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x < y |  !y   |   y   |
    //  ------+-------+-------+
    //    !x  |   0   | 0 < y |
    //     x  | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |      y       |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or a disjunctive
    // set here: since all the operands of kDenseOp must be dense, the
    // disjunctive set will eventually be optimized into a conjunctive set.
    if (expr.children.e1 == detail::kInvalidId) {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  llvm_unreachable("unexpected expression kind");
}

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0)).first;
}
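// Illustrative example: for a linalg.generic body of the form
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     linalg.yield %0 : f64
// buildTensorExpFromLinalg() starts from the yield value, and buildTensorExp()
// below produces kTensor leaves for %a and %b joined by a kMulF node.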
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
             cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

Type Merger::inferType(ExprId e, Value src) const {
  // Obtain the destination type from the cast node.
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = dyn_cast<VectorType>(src.getType()))
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
  return dtp;
}

/// Ensures that the sparsifier can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that the sparsifier can generate code for branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  // Recursion leaves.
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
      if (!op.isScalar(&t))
        return {addTensorExp(tid), hasSpDep};
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return {addInvariantExp(v), /*hasSpDep=*/false};
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return {addInvariantExp(v), /*hasSpDep=*/false};
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
  }

  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
      if (isa<complex::AbsOp>(def))
        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
      if (isa<math::AbsIOp>(def))
        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
      if (isa<math::CeilOp>(def))
        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
      if (isa<math::FloorOp>(def))
        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
      if (isa<math::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
      if (isa<complex::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
      if (isa<math::ExpM1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
      if (isa<complex::Expm1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
      if (isa<math::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
      if (isa<complex::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
      if (isa<math::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
      if (isa<complex::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
      if (isa<math::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
      if (isa<complex::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
      if (isa<arith::NegFOp>(def))
        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
      if (isa<complex::NegOp>(def))
        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
      if (isa<arith::TruncFOp>(def))
        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
      if (isa<arith::ExtFOp>(def))
        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
      if (isa<arith::FPToSIOp>(def))
        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
      if (isa<arith::FPToUIOp>(def))
        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
      if (isa<arith::SIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
      if (isa<arith::UIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
      if (isa<arith::ExtSIOp>(def))
        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
      if (isa<arith::ExtUIOp>(def))
        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
      if (isa<arith::IndexCastOp>(def))
        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
      if (isa<arith::TruncIOp>(def))
        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
      if (isa<complex::ImOp>(def))
        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
      if (isa<complex::ReOp>(def))
        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
      if (isa<arith::BitcastOp>(def))
        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    bool hasSpDep = xDepSp || yDepSp;
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
      if (isa<complex::MulOp>(def))
        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
      if (isa<arith::MulIOp>(def))
        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
      if (isa<arith::AddFOp>(def))
        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
      if (isa<complex::AddOp>(def))
        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
      if (isa<arith::AddIOp>(def))
        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
      if (isa<arith::SubFOp>(def))
        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
      if (isa<complex::SubOp>(def))
        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
      if (isa<arith::SubIOp>(def))
        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
      if (isa<arith::AndIOp>(def))
        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
      if (isa<arith::OrIOp>(def))
        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
      if (isa<arith::XOrIOp>(def))
        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
      if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
        if (ci.getPredicate() == arith::CmpIPredicate::eq ||
            ci.getPredicate() == arith::CmpIPredicate::sle ||
            ci.getPredicate() == arith::CmpIPredicate::sge ||
            ci.getPredicate() == arith::CmpIPredicate::ule ||
            ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify comparisons that include equality, because
          // 0 <= 0 yields true and thus densifies the result.
          return {std::nullopt, false};
        }

        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                        ci.getPredicateAttr());
        return {e, hasSpDep};
      }
      if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify comparisons that include equality, because
          // 0 <= 0 yields true and thus densifies the result.
          return {std::nullopt, false};
        }
        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                        cf.getPredicateAttr());
        return {e, hasSpDep};
      }
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
      }
    }
  }
  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
    bool hasSpDep = xDepSp || yDepSp || zDepSp;
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
      }
    }
  }

  // If we reach here, we are dealing with an operation that is not currently
  // sparsifiable. We can still generate code for it if all its operands only
  // have dense dependencies (i.e., all the values are loaded from dense
  // tensors).
  if (def->getNumResults() != 1) // only handle single result operation.
    return {std::nullopt, false};

  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
  // Build all the sub-expressions.
  for (Value operand : def->getOperands())
    subExp.push_back(buildTensorExp(op, operand));

  if (llvm::all_of(subExp,
                   [](auto e) { return e.first.has_value() && !e.second; })) {
    // All the subexpressions can be built and have *no* sparse dependencies.
    if (subExp.size() == 2) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      *subExp[1].first, def);
      return {e, false};
    }
    if (subExp.size() == 1) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      detail::kInvalidId, def);
      return {e, false};
    }
  }
  // Cannot build.
  return {std::nullopt, false};
}
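// Illustrative example: an operation without a dedicated TensorExp kind, say
// math.atan applied to values loaded from dense tensors only, still becomes
// expressible: since none of its operands carries a sparse dependency, it is
// wrapped in a kDenseOp node above and re-materialized later by cloning the
// original operation in buildExp().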
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
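// Illustrative example: once the loop emitter has materialized the scalar
// operands v0 and v1, buildExp() below regenerates the computation for each
// expression node; a kMulF node becomes arith.mulf, while the custom
// semi-ring kinds splice in their user-provided regions via insertYieldOp().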
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kCmpI: {
    auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kCmpF: {
    auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  case TensorExp::Kind::kDenseOp: {
    Operation *actualOp = expr.op;
    IRMapping mapping;
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir