//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
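  // Leaf expressions only record a tensor id, a loop id, or an invariant
  // value; the `children` union is left untouched for these kinds.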
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
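  // Binary expressions record both children (for kDenseOp, e1 may remain
  // kInvalidId); comparisons additionally record the predicate attribute so
  // it can be re-materialized during code generation.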
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    attr = a;
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kDenseOp:
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
  return eNew;
}

ExprId Merger::addSynZeroExp() {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute attr = exp(e).attr;
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
  return pNew;
}

LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
  return sNew;
}

LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = conjSet(e, s0, s1, op);
  TensorExp::Kind kind = exp(e).kind;

  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
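  // For example, for x - y the lattice points where only y is present would
  // compute 0 - y; mapping s1 through the corresponding negation kind below
  // lets codegen emit -y directly instead of materializing a zero operand.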
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}

LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build a disjunctive set for it.
    return sNew;
  }

  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}

LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {
  const LatSetId sNew = conjSet(e, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op) {
  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
         TensorExp::Kind::kDenseOp == kind);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
  }
  return sNew;
}

LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute a = exp(e).attr;
  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
  // Must be a binary operation.
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const ExprId zeroExp = addSynZeroExp();
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
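      // (I.e., the lattice point's expression is simply the output tensor
      // itself, which would only copy a value onto itself and can be skipped.)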
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}

BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e., the
  // last point in a lattice; no other point is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not of an undefined level type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate over the bits in reverse
  // to always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
          !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  case TensorExp::Kind::kDenseOp: {
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
    return lhsNeg;
  }
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
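  // For example, x(i) * c is a single condition on tensor x (the invariant
  // factor does not change the sparsity pattern), whereas x(i) + y(i) is a
  // single condition only if both operands are single conditions on the same
  // tensor t.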
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
    return false;
  case TensorExp::Kind::kDenseOp:
    // Since the Merger guarantees that all operands of a kDenseOp are dense,
    // the operation must be single-condition.
    return true;
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto lt = getLvlType(b);
    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
        is2OutOf4LT(lt))
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
      llvm::dbgs() << ")";
    } else {
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto lt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rules apply (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //   x < y |  !y   |   y   |
    //  -------+-------+-------+
    //     !x  |   0   | 0 < y |
    //      x  | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or disjunctive
    // set here: since all operands of a kDenseOp must be dense, a disjunctive
    // set would be optimized into a conjunctive set eventually.
    if (expr.children.e1 == detail::kInvalidId) {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  llvm_unreachable("unexpected expression kind");
}

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0)).first;
}

/// Only returns false if we are certain this is a nonzero.
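/// That is, the answer is conservative: returning true ("may be zero") is
/// always safe, and only an invariant constant that is provably nonzero
/// yields false.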
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
             cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

Type Merger::inferType(ExprId e, Value src) const {
  // Obtain the destination type from the cast node.
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = dyn_cast<VectorType>(src.getType()))
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
  return dtp;
}

/// Ensures that the sparsifier can generate code for the expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that the sparsifier can generate code for the branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  // Recursion leaves.
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
      if (!op.isScalar(&t))
        return {addTensorExp(tid), hasSpDep};
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return {addInvariantExp(v), /*hasSpDep=*/false};
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return {addInvariantExp(v), /*hasSpDep=*/false};
  // Construct index operations.
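  // (A linalg.index operation simply becomes a kLoopVar leaf for that loop.)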
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
  }

  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
      if (isa<complex::AbsOp>(def))
        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
      if (isa<math::AbsIOp>(def))
        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
      if (isa<math::CeilOp>(def))
        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
      if (isa<math::FloorOp>(def))
        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
      if (isa<math::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
      if (isa<complex::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
      if (isa<math::ExpM1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
      if (isa<complex::Expm1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
      if (isa<math::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
      if (isa<complex::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
      if (isa<math::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
      if (isa<complex::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
      if (isa<math::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
      if (isa<complex::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
      if (isa<arith::NegFOp>(def))
        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
      if (isa<complex::NegOp>(def))
        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
      if (isa<arith::TruncFOp>(def))
        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
      if (isa<arith::ExtFOp>(def))
        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
      if (isa<arith::FPToSIOp>(def))
        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
      if (isa<arith::FPToUIOp>(def))
        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
      if (isa<arith::SIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
      if (isa<arith::UIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
      if (isa<arith::ExtSIOp>(def))
        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
      if (isa<arith::ExtUIOp>(def))
        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
      if (isa<arith::IndexCastOp>(def))
        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
      if (isa<arith::TruncIOp>(def))
        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
      if (isa<complex::ImOp>(def))
        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
      if (isa<complex::ReOp>(def))
        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
      if (isa<arith::BitcastOp>(def))
        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    bool hasSpDep = xDepSp || yDepSp;
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
      if (isa<complex::MulOp>(def))
        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
      if (isa<arith::MulIOp>(def))
        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
      if (isa<arith::AddFOp>(def))
        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
      if (isa<complex::AddOp>(def))
        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
      if (isa<arith::AddIOp>(def))
        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
      if (isa<arith::SubFOp>(def))
        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
      if (isa<complex::SubOp>(def))
        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
      if (isa<arith::SubIOp>(def))
        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
      if (isa<arith::AndIOp>(def))
        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
      if (isa<arith::OrIOp>(def))
        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
      if (isa<arith::XOrIOp>(def))
        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
      if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
        if (ci.getPredicate() == arith::CmpIPredicate::eq ||
            ci.getPredicate() == arith::CmpIPredicate::sle ||
            ci.getPredicate() == arith::CmpIPredicate::sge ||
            ci.getPredicate() == arith::CmpIPredicate::ule ||
            ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify comparisons that test for equality, because
          // 0 <= 0 yields true and would thus densify the result.
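          // (For integer comparisons this leaves only the strict and
          // not-equal predicates, which all yield false when both operands
          // are zero.)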
          return {std::nullopt, false};
        }

        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                        ci.getPredicateAttr());
        return {e, hasSpDep};
      }
      if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify comparisons that test for equality, because
          // 0 <= 0 yields true and would thus densify the result.
          return {std::nullopt, false};
        }
        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                        cf.getPredicateAttr());
        return {e, hasSpDep};
      }
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
      }
    }
  }
  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
    bool hasSpDep = xDepSp || yDepSp || zDepSp;
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
      }
    }
  }

  // If we reach here, we are dealing with an operation that is not currently
  // sparsifiable. We can still generate code for it if all its operands only
  // have dense dependencies (i.e., all the values are loaded from dense
  // tensors).
  if (def->getNumResults() != 1) // only handle single result operation.
    return {std::nullopt, false};

  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
  // Build all the subexpressions.
  for (Value operand : def->getOperands())
    subExp.push_back(buildTensorExp(op, operand));

  if (llvm::all_of(subExp,
                   [](auto e) { return e.first.has_value() && !e.second; })) {
    // All the subexpressions can be built and have *no* sparse dependencies.
    if (subExp.size() == 2) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      *subExp[1].first, def);
      return {e, false};
    }
    if (subExp.size() == 1) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      detail::kInvalidId, def);
      return {e, false};
    }
  }
  // Cannot build.
  return {std::nullopt, false};
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
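  // (The block is inlined from a clone so that the original operation keeps
  // its region; the same region may need to be expanded again for other
  // lattice points.)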
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kCmpI: {
    auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kCmpF: {
    auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  case TensorExp::Kind::kDenseOp: {
    Operation *actualOp = expr.op;
    IRMapping mapping;
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir