1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 #include "mlir/Dialect/Arith/IR/Arith.h" 11 #include "mlir/Dialect/Complex/IR/Complex.h" 12 #include "mlir/Dialect/Math/IR/Math.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 15 #include "mlir/IR/Operation.h" 16 #include "llvm/Support/Debug.h" 17 #include <optional> 18 19 namespace mlir { 20 namespace sparse_tensor { 21 22 enum class ExpArity { 23 kNullary, 24 kUnary, 25 kBinary, 26 }; 27 28 static ExpArity getExpArity(TensorExp::Kind k) { 29 switch (k) { 30 // Leaf. 31 case TensorExp::Kind::kTensor: 32 case TensorExp::Kind::kInvariant: 33 case TensorExp::Kind::kLoopVar: 34 case TensorExp::Kind::kSynZero: 35 return ExpArity::kNullary; 36 case TensorExp::Kind::kAbsF: 37 case TensorExp::Kind::kAbsC: 38 case TensorExp::Kind::kAbsI: 39 case TensorExp::Kind::kCeilF: 40 case TensorExp::Kind::kFloorF: 41 case TensorExp::Kind::kSqrtF: 42 case TensorExp::Kind::kSqrtC: 43 case TensorExp::Kind::kExpm1F: 44 case TensorExp::Kind::kExpm1C: 45 case TensorExp::Kind::kLog1pF: 46 case TensorExp::Kind::kLog1pC: 47 case TensorExp::Kind::kSinF: 48 case TensorExp::Kind::kSinC: 49 case TensorExp::Kind::kTanhF: 50 case TensorExp::Kind::kTanhC: 51 case TensorExp::Kind::kTruncF: 52 case TensorExp::Kind::kExtF: 53 case TensorExp::Kind::kCastFS: 54 case TensorExp::Kind::kCastFU: 55 case TensorExp::Kind::kCastSF: 56 case TensorExp::Kind::kCastUF: 57 case TensorExp::Kind::kCastS: 58 case TensorExp::Kind::kCastU: 59 case TensorExp::Kind::kCastIdx: 60 case TensorExp::Kind::kTruncI: 61 case TensorExp::Kind::kCIm: 62 case TensorExp::Kind::kCRe: 63 
case TensorExp::Kind::kBitCast: 64 case TensorExp::Kind::kBinaryBranch: 65 case TensorExp::Kind::kUnary: 66 case TensorExp::Kind::kSelect: 67 case TensorExp::Kind::kNegF: 68 case TensorExp::Kind::kNegC: 69 case TensorExp::Kind::kNegI: 70 return ExpArity::kUnary; 71 // Binary operations. 72 case TensorExp::Kind::kDivF: 73 case TensorExp::Kind::kDivC: 74 case TensorExp::Kind::kDivS: 75 case TensorExp::Kind::kDivU: 76 case TensorExp::Kind::kShrS: 77 case TensorExp::Kind::kShrU: 78 case TensorExp::Kind::kShlI: 79 case TensorExp::Kind::kMulF: 80 case TensorExp::Kind::kMulC: 81 case TensorExp::Kind::kMulI: 82 case TensorExp::Kind::kAndI: 83 case TensorExp::Kind::kAddF: 84 case TensorExp::Kind::kAddC: 85 case TensorExp::Kind::kAddI: 86 case TensorExp::Kind::kOrI: 87 case TensorExp::Kind::kXorI: 88 case TensorExp::Kind::kBinary: 89 case TensorExp::Kind::kReduce: 90 case TensorExp::Kind::kSubF: 91 case TensorExp::Kind::kSubC: 92 case TensorExp::Kind::kSubI: 93 case TensorExp::Kind::kCmpF: 94 case TensorExp::Kind::kCmpI: 95 case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands 96 return ExpArity::kBinary; 97 } 98 llvm_unreachable("unexpected kind"); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // Constructors. 103 //===----------------------------------------------------------------------===// 104 105 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, 106 Operation *o, Attribute a) 107 : kind(k), val(v), op(o) { 108 switch (kind) { 109 // Leaf. 
110 case TensorExp::Kind::kTensor: 111 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); 112 tensor = x; 113 return; 114 case TensorExp::Kind::kSynZero: 115 assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o); 116 return; 117 case TensorExp::Kind::kInvariant: 118 assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); 119 return; 120 case TensorExp::Kind::kLoopVar: 121 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); 122 loop = x; 123 return; 124 // Unary operations. 125 case TensorExp::Kind::kAbsF: 126 case TensorExp::Kind::kAbsC: 127 case TensorExp::Kind::kAbsI: 128 case TensorExp::Kind::kCeilF: 129 case TensorExp::Kind::kFloorF: 130 case TensorExp::Kind::kSqrtF: 131 case TensorExp::Kind::kSqrtC: 132 case TensorExp::Kind::kExpm1F: 133 case TensorExp::Kind::kExpm1C: 134 case TensorExp::Kind::kLog1pF: 135 case TensorExp::Kind::kLog1pC: 136 case TensorExp::Kind::kSinF: 137 case TensorExp::Kind::kSinC: 138 case TensorExp::Kind::kTanhF: 139 case TensorExp::Kind::kTanhC: 140 case TensorExp::Kind::kNegF: 141 case TensorExp::Kind::kNegC: 142 case TensorExp::Kind::kNegI: 143 case TensorExp::Kind::kCIm: 144 case TensorExp::Kind::kCRe: 145 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); 146 children.e0 = x; 147 children.e1 = y; 148 return; 149 case TensorExp::Kind::kTruncF: 150 case TensorExp::Kind::kExtF: 151 case TensorExp::Kind::kCastFS: 152 case TensorExp::Kind::kCastFU: 153 case TensorExp::Kind::kCastSF: 154 case TensorExp::Kind::kCastUF: 155 case TensorExp::Kind::kCastS: 156 case TensorExp::Kind::kCastU: 157 case TensorExp::Kind::kCastIdx: 158 case TensorExp::Kind::kTruncI: 159 case TensorExp::Kind::kBitCast: 160 assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); 161 children.e0 = x; 162 children.e1 = y; 163 return; 164 case TensorExp::Kind::kBinaryBranch: 165 case TensorExp::Kind::kSelect: 166 assert(x != detail::kInvalidId && y == 
detail::kInvalidId && !v && o); 167 children.e0 = x; 168 children.e1 = y; 169 return; 170 case TensorExp::Kind::kUnary: 171 // No assertion on y can be made, as the branching paths involve both 172 // a unary (`mapSet`) and binary (`disjSet`) pathway. 173 assert(x != detail::kInvalidId && !v && o); 174 children.e0 = x; 175 children.e1 = y; 176 return; 177 // Binary operations. 178 case TensorExp::Kind::kMulF: 179 case TensorExp::Kind::kMulC: 180 case TensorExp::Kind::kMulI: 181 case TensorExp::Kind::kDivF: 182 case TensorExp::Kind::kDivC: 183 case TensorExp::Kind::kDivS: 184 case TensorExp::Kind::kDivU: 185 case TensorExp::Kind::kAddF: 186 case TensorExp::Kind::kAddC: 187 case TensorExp::Kind::kAddI: 188 case TensorExp::Kind::kSubF: 189 case TensorExp::Kind::kSubC: 190 case TensorExp::Kind::kSubI: 191 case TensorExp::Kind::kAndI: 192 case TensorExp::Kind::kOrI: 193 case TensorExp::Kind::kXorI: 194 case TensorExp::Kind::kShrS: 195 case TensorExp::Kind::kShrU: 196 case TensorExp::Kind::kShlI: 197 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); 198 children.e0 = x; 199 children.e1 = y; 200 return; 201 case TensorExp::Kind::kCmpF: 202 case TensorExp::Kind::kCmpI: 203 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); 204 attr = a; 205 children.e0 = x; 206 children.e1 = y; 207 return; 208 case TensorExp::Kind::kBinary: 209 case TensorExp::Kind::kReduce: 210 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); 211 children.e0 = x; 212 children.e1 = y; 213 return; 214 case TensorExp::Kind::kDenseOp: 215 assert(x != detail::kInvalidId && !v && o); 216 children.e0 = x; 217 children.e1 = y; 218 return; 219 } 220 llvm_unreachable("unexpected kind"); 221 } 222 223 Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, 224 unsigned numFilterLoops, unsigned maxLvlRank) 225 : outTensor(numInputOutputTensors - 1), 226 syntheticTensor(numInputOutputTensors), 227 numTensors(numInputOutputTensors + 1), 
numNativeLoops(numNativeLoops), 228 numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false), 229 lvlTypes(numTensors, 230 std::vector<DimLevelType>(numLoops, DimLevelType::Undef)), 231 loopToLvl(numTensors, 232 std::vector<std::optional<Level>>(numLoops, std::nullopt)), 233 lvlToLoop(numTensors, 234 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)), 235 loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlDLTPair>>( 236 numTensors, std::nullopt)), 237 levelToDependentLoop(numTensors, 238 std::vector<std::vector<LoopCoeffPair>>( 239 maxLvlRank, std::vector<LoopCoeffPair>())), 240 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} 241 242 //===----------------------------------------------------------------------===// 243 // Lattice methods. 244 //===----------------------------------------------------------------------===// 245 246 ExprId Merger::addTensorExp(TensorId t) { 247 assert(isValidTensorId(t)); 248 const ExprId eNew(tensorExps.size()); 249 tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId, 250 Value(), nullptr, nullptr); 251 return eNew; 252 } 253 254 ExprId Merger::addLoopVarExp(LoopId i) { 255 assert(isValidLoopId(i)); 256 const ExprId eNew(tensorExps.size()); 257 tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId, 258 Value(), nullptr, nullptr); 259 return eNew; 260 } 261 262 ExprId Merger::addInvariantExp(Value v) { 263 const ExprId eNew(tensorExps.size()); 264 tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, 265 detail::kInvalidId, v, nullptr, nullptr); 266 return eNew; 267 } 268 269 ExprId Merger::addSynZeroExp() { 270 const ExprId eNew(tensorExps.size()); 271 tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId, 272 detail::kInvalidId, Value(), nullptr, nullptr); 273 return eNew; 274 } 275 276 ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op, 277 Attribute attr) { 278 assert(k > 
TensorExp::Kind::kLoopVar); 279 const ExprId eNew(tensorExps.size()); 280 tensorExps.emplace_back(k, e0, e1, Value(), op, attr); 281 return eNew; 282 } 283 284 ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op, 285 Attribute attr) { 286 assert(k > TensorExp::Kind::kLoopVar); 287 const ExprId eNew(tensorExps.size()); 288 tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr); 289 return eNew; 290 } 291 292 LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { 293 const LatPointId pNew(latPoints.size()); 294 const unsigned size = numLoops * numTensors; 295 const TensorLoopId b = makeTensorLoopId(t, i); 296 latPoints.emplace_back(size, e); 297 latPoints[pNew].bits.set(b); 298 return pNew; 299 } 300 301 LatPointId Merger::addLat(const BitVector &bits, ExprId e) { 302 assert(bits.size() == numLoops * numTensors); 303 const LatPointId pNew(latPoints.size()); 304 latPoints.emplace_back(bits, e); 305 return pNew; 306 } 307 308 LatSetId Merger::addSet() { 309 const LatSetId sNew(latSets.size()); 310 latSets.emplace_back(); 311 return sNew; 312 } 313 314 LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1, 315 Operation *op) { 316 TensorExp::Kind kind = exp(e).kind; 317 Attribute attr = exp(e).attr; 318 const LatPointId pNew(latPoints.size()); 319 const auto &point0 = lat(p0); 320 const auto &point1 = lat(p1); 321 BitVector bits(point0.bits); 322 bits |= point1.bits; 323 const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr); 324 latPoints.emplace_back(bits, ne); 325 return pNew; 326 } 327 328 LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { 329 const LatSetId sNew = addSet(); 330 auto &setNew = latSets[sNew]; 331 for (const LatPointId p0 : set(s0)) 332 for (const LatPointId p1 : set(s1)) 333 setNew.push_back(conjLat(e, p0, p1, op)); 334 return sNew; 335 } 336 337 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { 338 const LatSetId sNew = conjSet(e, s0, s1, 
op); 339 TensorExp::Kind kind = exp(e).kind; 340 341 // Followed by all in s0. 342 latSets[sNew].append(latSets[s0]); 343 // Map binary 0-y to unary -y. 344 // TODO: move this if-else logic into buildLattices 345 if (kind == TensorExp::Kind::kSubF) 346 s1 = mapSet(TensorExp::Kind::kNegF, s1); 347 else if (kind == TensorExp::Kind::kSubC) 348 s1 = mapSet(TensorExp::Kind::kNegC, s1); 349 else if (kind == TensorExp::Kind::kSubI) 350 s1 = mapSet(TensorExp::Kind::kNegI, s1); 351 // Followed by all in s1. 352 latSets[sNew].append(latSets[s1]); 353 return sNew; 354 } 355 356 LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) { 357 assert(exp(e).kind == TensorExp::Kind::kCmpI || 358 exp(e).kind == TensorExp::Kind::kCmpF); 359 const LatSetId sNew = conjSet(e, s0, s1, nullptr); 360 361 ExprId e0 = exp(e).children.e0; 362 ExprId e1 = exp(e).children.e1; 363 if (exp(e0).kind == TensorExp::Kind::kSynZero || 364 exp(e1).kind == TensorExp::Kind::kSynZero) { 365 // lhs and rhs can't be synthetic zero at the same time. 366 assert(exp(e0).kind != exp(e1).kind); 367 // If one of the operands has already been assigned to zero (the 368 // element is absent in the corresponding operand), then we do not 369 // need to build disjunctive set for it. 370 return sNew; 371 } 372 373 auto lhsSet = mapBinWithSynZeroSet(e, s0, false); 374 auto rhsSet = mapBinWithSynZeroSet(e, s1, true); 375 latSets[sNew].append(latSets[lhsSet]); 376 latSets[sNew].append(latSets[rhsSet]); 377 return sNew; 378 } 379 380 LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, 381 bool includeLeft, TensorExp::Kind ltrans, 382 Operation *opleft, bool includeRight, 383 TensorExp::Kind rtrans, Operation *opright) { 384 const LatSetId sNew = conjSet(e, s0, s1, orig); 385 // Left Region. 386 if (includeLeft) { 387 if (opleft) 388 s0 = mapSet(ltrans, s0, Value(), opleft); 389 latSets[sNew].append(latSets[s0]); 390 } 391 // Right Region. 
392 if (includeRight) { 393 if (opright) 394 s1 = mapSet(rtrans, s1, Value(), opright); 395 latSets[sNew].append(latSets[s1]); 396 } 397 return sNew; 398 } 399 400 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, 401 Operation *op) { 402 assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) || 403 TensorExp::Kind::kDenseOp == kind); 404 const LatSetId sNew = addSet(); 405 auto &setNew = latSets[sNew]; 406 for (const LatPointId p : set(s0)) { 407 const auto &point = latPoints[p]; 408 setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op))); 409 } 410 return sNew; 411 } 412 413 LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) { 414 TensorExp::Kind kind = exp(e).kind; 415 Attribute a = exp(e).attr; 416 assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI); 417 // Must be a binary operation. 418 const LatSetId sNew = addSet(); 419 auto &setNew = latSets[sNew]; 420 const ExprId zeroExp = addSynZeroExp(); 421 for (const LatPointId p : set(s0)) { 422 const auto &point = latPoints[p]; 423 ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a) 424 : addExp(kind, point.exp, zeroExp, nullptr, a); 425 setNew.push_back(addLat(point.bits, newExp)); 426 } 427 return sNew; 428 } 429 430 LatSetId Merger::optimizeSet(LatSetId s0) { 431 const LatSetId sNew = addSet(); 432 auto &setNew = latSets[sNew]; 433 const auto &set0 = set(s0); 434 assert(!set0.empty()); 435 const LatPointId p0 = set0[0]; 436 for (const LatPointId p1 : set0) { 437 bool add = true; 438 if (p0 != p1) { 439 // Check whether this is a straightforward copy. 440 if (expIsTensor(latPoints[p1].exp, outTensor)) 441 continue; 442 // Check whether this conjunction is already covered. 
443 for (const LatPointId p2 : setNew) { 444 assert(!latGT(p1, p2)); // Lj => Li would be bad 445 if (onlyDenseDiff(p2, p1)) { 446 add = false; 447 break; 448 } 449 } 450 assert(!add || latGT(p0, p1)); 451 } 452 if (add) 453 setNew.push_back(p1); 454 } 455 for (const LatPointId p : setNew) 456 latPoints[p].simple = simplifyCond(sNew, p); 457 return sNew; 458 } 459 460 BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { 461 // First determine if this lattice point is a *singleton*, i.e., 462 // the last point in a lattice, no other is less than this one. 463 bool isSingleton = true; 464 for (const LatPointId p1 : set(s0)) { 465 if (p0 != p1 && latGT(p0, p1)) { 466 isSingleton = false; 467 break; 468 } 469 } 470 471 BitVector simple(latPoints[p0].bits); 472 bool reset = isSingleton && hasAnySparse(simple); 473 const TensorLoopId be = simple.size(); 474 TensorLoopId offset = 0; // relative to the end 475 if (!reset) 476 // Starts resetting from a dense level, so that the first bit (if kept) 477 // is not undefined level-type. 478 for (unsigned b = 0; b < be; b++) { 479 if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) { 480 offset = be - b - 1; // relative to the end 481 break; 482 } 483 } 484 485 // Now apply the two basic rules. We also iterate the bits reversely to always 486 // keep the rightmost bit (which could possibly be a synthetic tensor). 487 for (unsigned b = be - 1 - offset, i = 0; i < be; 488 b = b == 0 ? be - 1 : b - 1, i++) { 489 // Slice on dense level has `locate` property as well, and can be optimized. 
490 if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { 491 const auto dlt = getLvlType(b); 492 if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) && 493 !isLooseCompressedDLT(dlt)) { 494 if (reset) 495 simple.reset(b); 496 reset = true; 497 } 498 } 499 } 500 return simple; 501 } 502 503 bool Merger::latGT(LatPointId i, LatPointId j) const { 504 const BitVector &bitsi = lat(i).bits; 505 const BitVector &bitsj = lat(j).bits; 506 assert(bitsi.size() == bitsj.size()); 507 if (bitsi.count() > bitsj.count()) { 508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) 509 if (bitsj[b] && !bitsi[b]) 510 return false; 511 return true; 512 } 513 return false; 514 } 515 516 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const { 517 BitVector tmp(latPoints[j].bits); 518 tmp ^= latPoints[i].bits; 519 return !hasAnySparse(tmp); 520 } 521 522 bool Merger::expContainsTensor(ExprId e, TensorId t) const { 523 const auto &expr = exp(e); 524 // First we check `expIsTensor`. 525 if (expr.kind == TensorExp::Kind::kTensor) 526 return expr.tensor == t; 527 528 switch (getExpArity(expr.kind)) { 529 case ExpArity::kNullary: 530 return false; 531 case ExpArity::kUnary: { 532 const ExprId e0 = expr.children.e0; 533 return expContainsTensor(e0, t); 534 } 535 case ExpArity::kBinary: { 536 const ExprId e0 = expr.children.e0; 537 const ExprId e1 = expr.children.e1; 538 return expContainsTensor(e0, t) || expContainsTensor(e1, t); 539 } 540 } 541 llvm_unreachable("unexpected arity"); 542 } 543 544 bool Merger::hasNegateOnOut(ExprId e) const { 545 const auto &expr = exp(e); 546 switch (expr.kind) { 547 case TensorExp::Kind::kNegF: 548 case TensorExp::Kind::kNegC: 549 case TensorExp::Kind::kNegI: 550 return expContainsTensor(expr.children.e0, outTensor); 551 case TensorExp::Kind::kSubF: 552 case TensorExp::Kind::kSubC: 553 case TensorExp::Kind::kSubI: 554 return expContainsTensor(expr.children.e1, outTensor) || 555 hasNegateOnOut(expr.children.e0); 556 case TensorExp::Kind::kDenseOp: { 
557 bool lhsNeg = hasNegateOnOut(expr.children.e0); 558 if (!lhsNeg && expr.children.e1 != detail::kInvalidId) 559 return hasNegateOnOut(expr.children.e1); 560 return lhsNeg; 561 } 562 default: { 563 switch (getExpArity(expr.kind)) { 564 case ExpArity::kNullary: 565 return false; 566 case ExpArity::kUnary: 567 return hasNegateOnOut(expr.children.e0); 568 case ExpArity::kBinary: 569 return hasNegateOnOut(expr.children.e0) || 570 hasNegateOnOut(expr.children.e1); 571 } 572 } 573 } 574 llvm_unreachable("unexpected kind"); 575 } 576 577 bool Merger::isSingleCondition(TensorId t, ExprId e) const { 578 assert(isValidTensorId(t)); 579 const auto &expr = exp(e); 580 switch (expr.kind) { 581 // Leaf. 582 case TensorExp::Kind::kTensor: 583 return expr.tensor == t; 584 case TensorExp::Kind::kInvariant: 585 case TensorExp::Kind::kLoopVar: 586 case TensorExp::Kind::kSynZero: 587 return false; 588 // Unary operations. 589 case TensorExp::Kind::kAbsF: 590 case TensorExp::Kind::kAbsC: 591 case TensorExp::Kind::kAbsI: 592 case TensorExp::Kind::kCeilF: 593 case TensorExp::Kind::kFloorF: 594 case TensorExp::Kind::kSqrtF: 595 case TensorExp::Kind::kSqrtC: 596 case TensorExp::Kind::kExpm1F: 597 case TensorExp::Kind::kExpm1C: 598 case TensorExp::Kind::kLog1pF: 599 case TensorExp::Kind::kLog1pC: 600 case TensorExp::Kind::kSinF: 601 case TensorExp::Kind::kSinC: 602 case TensorExp::Kind::kTanhF: 603 case TensorExp::Kind::kTanhC: 604 case TensorExp::Kind::kNegF: 605 case TensorExp::Kind::kNegC: 606 case TensorExp::Kind::kNegI: 607 case TensorExp::Kind::kTruncF: 608 case TensorExp::Kind::kExtF: 609 case TensorExp::Kind::kCastFS: 610 case TensorExp::Kind::kCastFU: 611 case TensorExp::Kind::kCastSF: 612 case TensorExp::Kind::kCastUF: 613 case TensorExp::Kind::kCastS: 614 case TensorExp::Kind::kCastU: 615 case TensorExp::Kind::kCastIdx: 616 case TensorExp::Kind::kTruncI: 617 case TensorExp::Kind::kCIm: 618 case TensorExp::Kind::kCRe: 619 case TensorExp::Kind::kBitCast: 620 case 
TensorExp::Kind::kUnary: 621 return isSingleCondition(t, expr.children.e0); 622 case TensorExp::Kind::kBinaryBranch: 623 case TensorExp::Kind::kSelect: 624 return false; 625 // Binary operations. 626 case TensorExp::Kind::kDivF: // note: x / c only 627 case TensorExp::Kind::kDivC: 628 case TensorExp::Kind::kDivS: 629 case TensorExp::Kind::kDivU: 630 assert(!maybeZero(expr.children.e1)); 631 return isSingleCondition(t, expr.children.e0); 632 case TensorExp::Kind::kShrS: // note: x >> inv only 633 case TensorExp::Kind::kShrU: 634 case TensorExp::Kind::kShlI: 635 assert(isInvariant(expr.children.e1)); 636 return isSingleCondition(t, expr.children.e0); 637 case TensorExp::Kind::kMulF: 638 case TensorExp::Kind::kMulC: 639 case TensorExp::Kind::kMulI: 640 case TensorExp::Kind::kAndI: 641 case TensorExp::Kind::kReduce: 642 if (isSingleCondition(t, expr.children.e0)) 643 return isSingleCondition(t, expr.children.e1) || 644 isInvariant(expr.children.e1); 645 if (isSingleCondition(t, expr.children.e1)) 646 return isInvariant(expr.children.e0); 647 return false; 648 case TensorExp::Kind::kAddF: 649 case TensorExp::Kind::kAddC: 650 case TensorExp::Kind::kAddI: 651 return isSingleCondition(t, expr.children.e0) && 652 isSingleCondition(t, expr.children.e1); 653 case TensorExp::Kind::kSubF: 654 case TensorExp::Kind::kSubC: 655 case TensorExp::Kind::kSubI: 656 case TensorExp::Kind::kOrI: 657 case TensorExp::Kind::kXorI: 658 case TensorExp::Kind::kCmpF: 659 case TensorExp::Kind::kCmpI: 660 case TensorExp::Kind::kBinary: 661 return false; 662 case TensorExp::Kind::kDenseOp: 663 // Since Merger guarantees all the operands of the kDenseOp to be dense, the 664 // operation must be single-condition. 
665 return true; 666 } 667 llvm_unreachable("unexpected kind"); 668 } 669 670 bool Merger::hasAnySparse(const BitVector &bits) const { 671 for (TensorLoopId b : bits.set_bits()) { 672 const auto dlt = getLvlType(b); 673 if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || 674 isLooseCompressedDLT(dlt)) 675 return true; 676 } 677 return hasSparseIdxReduction(bits); 678 } 679 680 bool Merger::hasSparseIdxReduction(const BitVector &bits) const { 681 for (TensorLoopId b : bits.set_bits()) 682 if (isSparseLvlWithNonTrivialIdxExp(b)) 683 return true; 684 return false; 685 } 686 687 #ifndef NDEBUG 688 689 //===----------------------------------------------------------------------===// 690 // Print methods (for debugging). 691 //===----------------------------------------------------------------------===// 692 693 static const char *kindToOpSymbol(TensorExp::Kind kind) { 694 switch (kind) { 695 // Leaf. 696 case TensorExp::Kind::kTensor: 697 return "tensor"; 698 case TensorExp::Kind::kInvariant: 699 return "invariant"; 700 case TensorExp::Kind::kLoopVar: 701 return "index"; 702 case TensorExp::Kind::kSynZero: 703 return "0"; 704 // Unary operations. 
705 case TensorExp::Kind::kAbsF: 706 case TensorExp::Kind::kAbsC: 707 case TensorExp::Kind::kAbsI: 708 return "abs"; 709 case TensorExp::Kind::kCeilF: 710 return "ceil"; 711 case TensorExp::Kind::kFloorF: 712 return "floor"; 713 case TensorExp::Kind::kSqrtF: 714 case TensorExp::Kind::kSqrtC: 715 return "sqrt"; 716 case TensorExp::Kind::kExpm1F: 717 case TensorExp::Kind::kExpm1C: 718 return "expm1"; 719 case TensorExp::Kind::kLog1pF: 720 case TensorExp::Kind::kLog1pC: 721 return "log1p"; 722 case TensorExp::Kind::kSinF: 723 case TensorExp::Kind::kSinC: 724 return "sin"; 725 case TensorExp::Kind::kTanhF: 726 case TensorExp::Kind::kTanhC: 727 return "tanh"; 728 case TensorExp::Kind::kNegF: 729 case TensorExp::Kind::kNegC: 730 case TensorExp::Kind::kNegI: 731 return "-"; 732 case TensorExp::Kind::kTruncF: 733 case TensorExp::Kind::kExtF: 734 case TensorExp::Kind::kCastFS: 735 case TensorExp::Kind::kCastFU: 736 case TensorExp::Kind::kCastSF: 737 case TensorExp::Kind::kCastUF: 738 case TensorExp::Kind::kCastS: 739 case TensorExp::Kind::kCastU: 740 case TensorExp::Kind::kCastIdx: 741 case TensorExp::Kind::kTruncI: 742 case TensorExp::Kind::kCIm: 743 return "complex.im"; 744 case TensorExp::Kind::kCRe: 745 return "complex.re"; 746 case TensorExp::Kind::kBitCast: 747 return "cast"; 748 case TensorExp::Kind::kBinaryBranch: 749 return "binary_branch"; 750 case TensorExp::Kind::kUnary: 751 return "unary"; 752 case TensorExp::Kind::kSelect: 753 return "select"; 754 // Binary operations. 
755 case TensorExp::Kind::kMulF: 756 case TensorExp::Kind::kMulC: 757 case TensorExp::Kind::kMulI: 758 return "*"; 759 case TensorExp::Kind::kDivF: 760 case TensorExp::Kind::kDivC: 761 case TensorExp::Kind::kDivS: 762 case TensorExp::Kind::kDivU: 763 return "/"; 764 case TensorExp::Kind::kAddF: 765 case TensorExp::Kind::kAddC: 766 case TensorExp::Kind::kAddI: 767 return "+"; 768 case TensorExp::Kind::kSubF: 769 case TensorExp::Kind::kSubC: 770 case TensorExp::Kind::kSubI: 771 return "-"; 772 case TensorExp::Kind::kAndI: 773 return "&"; 774 case TensorExp::Kind::kOrI: 775 return "|"; 776 case TensorExp::Kind::kXorI: 777 return "^"; 778 case TensorExp::Kind::kShrS: 779 return "a>>"; 780 case TensorExp::Kind::kShrU: 781 return ">>"; 782 case TensorExp::Kind::kShlI: 783 return "<<"; 784 case TensorExp::Kind::kCmpF: 785 case TensorExp::Kind::kCmpI: 786 return "cmp"; 787 case TensorExp::Kind::kBinary: 788 return "binary"; 789 case TensorExp::Kind::kReduce: 790 return "reduce"; 791 case TensorExp::Kind::kDenseOp: 792 return "dense"; 793 } 794 llvm_unreachable("unexpected kind for symbol"); 795 } 796 797 void Merger::dumpExp(ExprId e) const { 798 const auto &expr = exp(e); 799 switch (expr.kind) { 800 // Leaf. 801 case TensorExp::Kind::kTensor: 802 if (expr.tensor == syntheticTensor) 803 llvm::dbgs() << "synthetic_"; 804 else if (expr.tensor == outTensor) 805 llvm::dbgs() << "output_"; 806 llvm::dbgs() << "tensor_" << expr.tensor; 807 break; 808 case TensorExp::Kind::kInvariant: 809 llvm::dbgs() << "invariant"; 810 break; 811 case TensorExp::Kind::kSynZero: 812 llvm::dbgs() << "0"; 813 break; 814 case TensorExp::Kind::kLoopVar: 815 llvm::dbgs() << "loopvar_" << expr.loop; 816 break; 817 // Unary operations. 
818 case TensorExp::Kind::kAbsF: 819 case TensorExp::Kind::kAbsC: 820 case TensorExp::Kind::kAbsI: 821 case TensorExp::Kind::kCeilF: 822 case TensorExp::Kind::kFloorF: 823 case TensorExp::Kind::kSqrtF: 824 case TensorExp::Kind::kSqrtC: 825 case TensorExp::Kind::kExpm1F: 826 case TensorExp::Kind::kExpm1C: 827 case TensorExp::Kind::kLog1pF: 828 case TensorExp::Kind::kLog1pC: 829 case TensorExp::Kind::kSinF: 830 case TensorExp::Kind::kSinC: 831 case TensorExp::Kind::kTanhF: 832 case TensorExp::Kind::kTanhC: 833 case TensorExp::Kind::kNegF: 834 case TensorExp::Kind::kNegC: 835 case TensorExp::Kind::kNegI: 836 case TensorExp::Kind::kTruncF: 837 case TensorExp::Kind::kExtF: 838 case TensorExp::Kind::kCastFS: 839 case TensorExp::Kind::kCastFU: 840 case TensorExp::Kind::kCastSF: 841 case TensorExp::Kind::kCastUF: 842 case TensorExp::Kind::kCastS: 843 case TensorExp::Kind::kCastU: 844 case TensorExp::Kind::kCastIdx: 845 case TensorExp::Kind::kTruncI: 846 case TensorExp::Kind::kCIm: 847 case TensorExp::Kind::kCRe: 848 case TensorExp::Kind::kBitCast: 849 case TensorExp::Kind::kBinaryBranch: 850 case TensorExp::Kind::kUnary: 851 case TensorExp::Kind::kSelect: 852 llvm::dbgs() << kindToOpSymbol(expr.kind) << " "; 853 dumpExp(expr.children.e0); 854 break; 855 // Binary operations. 
856 case TensorExp::Kind::kMulF: 857 case TensorExp::Kind::kMulC: 858 case TensorExp::Kind::kMulI: 859 case TensorExp::Kind::kDivF: 860 case TensorExp::Kind::kDivC: 861 case TensorExp::Kind::kDivS: 862 case TensorExp::Kind::kDivU: 863 case TensorExp::Kind::kAddF: 864 case TensorExp::Kind::kAddC: 865 case TensorExp::Kind::kAddI: 866 case TensorExp::Kind::kSubF: 867 case TensorExp::Kind::kSubC: 868 case TensorExp::Kind::kSubI: 869 case TensorExp::Kind::kAndI: 870 case TensorExp::Kind::kOrI: 871 case TensorExp::Kind::kXorI: 872 case TensorExp::Kind::kShrS: 873 case TensorExp::Kind::kShrU: 874 case TensorExp::Kind::kShlI: 875 case TensorExp::Kind::kCmpF: 876 case TensorExp::Kind::kCmpI: 877 case TensorExp::Kind::kBinary: 878 case TensorExp::Kind::kReduce: 879 case TensorExp::Kind::kDenseOp: 880 llvm::dbgs() << "("; 881 dumpExp(expr.children.e0); 882 llvm::dbgs() << " " << kindToOpSymbol(expr.kind); 883 if (expr.attr) 884 llvm::dbgs() << "{" << expr.attr << "}"; 885 if (expr.children.e1 != detail::kInvalidId) { 886 llvm::dbgs() << " "; 887 dumpExp(expr.children.e1); 888 llvm::dbgs() << ")"; 889 } else { 890 assert(expr.kind == TensorExp::Kind::kDenseOp); 891 } 892 break; 893 } 894 } 895 896 void Merger::dumpLat(LatPointId p) const { 897 const auto &point = lat(p); 898 llvm::dbgs() << "lat("; 899 dumpBits(point.bits); 900 llvm::dbgs() << " :"; 901 dumpBits(point.simple); 902 llvm::dbgs() << " : "; 903 dumpExp(point.exp); 904 llvm::dbgs() << " )\n"; 905 } 906 907 void Merger::dumpSet(LatSetId s) const { 908 const auto &ss = set(s); 909 llvm::dbgs() << "{ #" << ss.size() << "\n"; 910 for (const LatPointId p : ss) { 911 llvm::dbgs() << " "; 912 dumpLat(p); 913 } 914 llvm::dbgs() << "}\n"; 915 } 916 917 void Merger::dumpBits(const BitVector &bits) const { 918 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { 919 if (bits[b]) { 920 const TensorId t = tensor(b); 921 const LoopId i = loop(b); 922 const auto dlt = lvlTypes[t][i]; 923 if (isLvlWithNonTrivialIdxExp(b)) 924 
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

/// Constructs the iteration lattice sets for expression `e` with respect to
/// loop `i`, recursively merging operand lattices with the conjunction or
/// disjunction rule dictated by each operator kind.
LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      // A sparse output tensor is handled via the synthetic tensor as well.
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    // -y|!y | y |
    // --+---+---+
    //   | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    // op y|    !y    |     y      |
    // ----+----------+------------+
    //     | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    // x*y|!y | y |
    // ---+---+---+
    // !x | 0 | 0 |
    //  x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    // x/y|!y |y |
    // ---+---+---+
    // !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //  x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rules applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      // Division by a possible zero was rejected during expression building.
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    // x+y|!y | y |    x-y|!y | y |
    // ---+---+---+    ---+---+---+
    // !x | 0 | y |    !x | 0 |-y |
    //  x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    // x < y |  !y   |   y   |
    // ------+-------+-------+
    //  !x   |   0   | 0 < y |
    //   x   | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      // Shifting by a non-invariant was rejected during expression building.
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    // x op y|   !y    |       y      |
    // ------+---------+--------------+
    //   !x  |  empty  |   right(y)   |
    //    x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use conjunctive/disjunctive set
    // here, as all the operands of kDenseOp must be dense, the disjunctive set
    // will be optimized into conjunctive set eventually.
    if (expr.children.e1 == detail::kInvalidId) {
      // Unary dense operation.
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    // Binary dense operation.
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  llvm_unreachable("unexpected expression kind");
}

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0)).first;
}

/// Only returns false if we are certain this is a nonzero.
1196 bool Merger::maybeZero(ExprId e) const { 1197 const auto &expr = exp(e); 1198 if (expr.kind == TensorExp::Kind::kInvariant) { 1199 if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) { 1200 ArrayAttr arrayAttr = c.getValue(); 1201 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && 1202 cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); 1203 } 1204 if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>()) 1205 return c.value() == 0; 1206 if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>()) 1207 return c.value().isZero(); 1208 } 1209 return true; 1210 } 1211 1212 Type Merger::inferType(ExprId e, Value src) const { 1213 // Obtain the destination type from the cast node. 1214 Type dtp = exp(e).val.getType(); 1215 // Inspect source type. For vector types, apply the same 1216 // vectorization to the destination type. 1217 if (auto vtp = dyn_cast<VectorType>(src.getType())) 1218 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims()); 1219 return dtp; 1220 } 1221 1222 /// Ensures that sparse compiler can generate code for expression. 1223 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { 1224 // Arguments are always admissible. 1225 if (isa<BlockArgument>(v)) 1226 return true; 1227 // Accept index anywhere. 1228 Operation *def = v.getDefiningOp(); 1229 if (isa<linalg::IndexOp>(def)) 1230 return true; 1231 // Operation defined outside branch. 1232 if (def->getBlock() != block) 1233 return def->getBlock() != op->getBlock(); // invariant? 1234 // Operation defined within branch. Anything is accepted, 1235 // as long as all subexpressions are admissible. 1236 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) 1237 if (!isAdmissibleBranchExp(op, block, def->getOperand(i))) 1238 return false; 1239 return true; 1240 } 1241 1242 /// Ensures that sparse compiler can generate code for branch. 
1243 static bool isAdmissibleBranch(Operation *op, Region ®ion) { 1244 if (region.empty()) 1245 return true; 1246 // Build the semi-ring branch semantics backward from yield. 1247 Operation *yield = region.front().getTerminator(); 1248 assert(isa<YieldOp>(yield)); 1249 return isAdmissibleBranchExp(op, ®ion.front(), yield->getOperand(0)); 1250 } 1251 1252 std::pair<std::optional<ExprId>, bool> 1253 Merger::buildTensorExp(linalg::GenericOp op, Value v) { 1254 // Recursion leaves. 1255 if (auto arg = dyn_cast<BlockArgument>(v)) { 1256 const TensorId tid = makeTensorId(arg.getArgNumber()); 1257 // Any argument of the generic op that is not marked as a scalar 1258 // argument is considered a tensor, indexed by the implicit loop 1259 // bounds. This includes rank-0 tensor arguments. 1260 if (arg.getOwner()->getParentOp() == op) { 1261 OpOperand &t = op->getOpOperand(tid); 1262 bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr; 1263 if (!op.isScalar(&t)) 1264 return {addTensorExp(tid), hasSpDep}; 1265 v = t.get(); // get scalar value 1266 } 1267 // Any other argument (marked as scalar argument for the generic op 1268 // or belonging to an enveloping op) is considered invariant. 1269 return {addInvariantExp(v), /*hasSpDep=*/false}; 1270 } 1271 // Something defined outside is invariant. 1272 Operation *def = v.getDefiningOp(); 1273 if (def->getBlock() != &op.getRegion().front()) 1274 return {addInvariantExp(v), /*hasSpDep=*/false}; 1275 // Construct index operations. 1276 if (def->getNumOperands() == 0) { 1277 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 1278 return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false}; 1279 } 1280 1281 // Construct unary operations if subexpression can be built. 
1282 if (def->getNumOperands() == 1) { 1283 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0)); 1284 if (x.has_value()) { 1285 const ExprId e = *x; 1286 if (isa<math::AbsFOp>(def)) 1287 return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep}; 1288 if (isa<complex::AbsOp>(def)) 1289 return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep}; 1290 if (isa<math::AbsIOp>(def)) 1291 return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep}; 1292 if (isa<math::CeilOp>(def)) 1293 return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep}; 1294 if (isa<math::FloorOp>(def)) 1295 return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep}; 1296 if (isa<math::SqrtOp>(def)) 1297 return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep}; 1298 if (isa<complex::SqrtOp>(def)) 1299 return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep}; 1300 if (isa<math::ExpM1Op>(def)) 1301 return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep}; 1302 if (isa<complex::Expm1Op>(def)) 1303 return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep}; 1304 if (isa<math::Log1pOp>(def)) 1305 return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep}; 1306 if (isa<complex::Log1pOp>(def)) 1307 return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep}; 1308 if (isa<math::SinOp>(def)) 1309 return {addExp(TensorExp::Kind::kSinF, e), hasSpDep}; 1310 if (isa<complex::SinOp>(def)) 1311 return {addExp(TensorExp::Kind::kSinC, e), hasSpDep}; 1312 if (isa<math::TanhOp>(def)) 1313 return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep}; 1314 if (isa<complex::TanhOp>(def)) 1315 return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep}; 1316 if (isa<arith::NegFOp>(def)) 1317 return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std 1318 if (isa<complex::NegOp>(def)) 1319 return {addExp(TensorExp::Kind::kNegC, e), hasSpDep}; 1320 if (isa<arith::TruncFOp>(def)) 1321 return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep}; 1322 if (isa<arith::ExtFOp>(def)) 1323 return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep}; 1324 if 
(isa<arith::FPToSIOp>(def)) 1325 return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep}; 1326 if (isa<arith::FPToUIOp>(def)) 1327 return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep}; 1328 if (isa<arith::SIToFPOp>(def)) 1329 return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep}; 1330 if (isa<arith::UIToFPOp>(def)) 1331 return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep}; 1332 if (isa<arith::ExtSIOp>(def)) 1333 return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep}; 1334 if (isa<arith::ExtUIOp>(def)) 1335 return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep}; 1336 if (isa<arith::IndexCastOp>(def)) 1337 return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep}; 1338 if (isa<arith::TruncIOp>(def)) 1339 return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep}; 1340 if (isa<complex::ImOp>(def)) 1341 return {addExp(TensorExp::Kind::kCIm, e), hasSpDep}; 1342 if (isa<complex::ReOp>(def)) 1343 return {addExp(TensorExp::Kind::kCRe, e), hasSpDep}; 1344 if (isa<arith::BitcastOp>(def)) 1345 return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep}; 1346 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) { 1347 if (isAdmissibleBranch(unop, unop.getPresentRegion()) && 1348 isAdmissibleBranch(unop, unop.getAbsentRegion())) 1349 return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep}; 1350 } 1351 if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) { 1352 if (isAdmissibleBranch(selop, selop.getRegion())) 1353 return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep}; 1354 } 1355 } 1356 } 1357 // Construct binary operations if subexpressions can be built. 1358 // See buildLattices() for an explanation of rejecting certain 1359 // division and shift operations. 
1360 if (def->getNumOperands() == 2) { 1361 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0)); 1362 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1)); 1363 bool hasSpDep = xDepSp || yDepSp; 1364 if (x.has_value() && y.has_value()) { 1365 const ExprId e0 = *x; 1366 const ExprId e1 = *y; 1367 if (isa<arith::MulFOp>(def)) 1368 return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep}; 1369 if (isa<complex::MulOp>(def)) 1370 return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep}; 1371 if (isa<arith::MulIOp>(def)) 1372 return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep}; 1373 if (isa<arith::DivFOp>(def) && !maybeZero(e1)) 1374 return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep}; 1375 if (isa<complex::DivOp>(def) && !maybeZero(e1)) 1376 return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep}; 1377 if (isa<arith::DivSIOp>(def) && !maybeZero(e1)) 1378 return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep}; 1379 if (isa<arith::DivUIOp>(def) && !maybeZero(e1)) 1380 return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep}; 1381 if (isa<arith::AddFOp>(def)) 1382 return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep}; 1383 if (isa<complex::AddOp>(def)) 1384 return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep}; 1385 if (isa<arith::AddIOp>(def)) 1386 return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep}; 1387 if (isa<arith::SubFOp>(def)) 1388 return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep}; 1389 if (isa<complex::SubOp>(def)) 1390 return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep}; 1391 if (isa<arith::SubIOp>(def)) 1392 return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep}; 1393 if (isa<arith::AndIOp>(def)) 1394 return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep}; 1395 if (isa<arith::OrIOp>(def)) 1396 return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep}; 1397 if (isa<arith::XOrIOp>(def)) 1398 return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep}; 1399 if (isa<arith::ShRSIOp>(def) && 
isInvariant(e1)) 1400 return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep}; 1401 if (isa<arith::ShRUIOp>(def) && isInvariant(e1)) 1402 return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep}; 1403 if (isa<arith::ShLIOp>(def) && isInvariant(e1)) 1404 return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep}; 1405 if (auto ci = dyn_cast<arith::CmpIOp>(def)) { 1406 if (ci.getPredicate() == arith::CmpIPredicate::eq && 1407 ci.getPredicate() == arith::CmpIPredicate::sle && 1408 ci.getPredicate() == arith::CmpIPredicate::sge && 1409 ci.getPredicate() == arith::CmpIPredicate::ule && 1410 ci.getPredicate() == arith::CmpIPredicate::uge) { 1411 // We can not sparsify comparison with equal, this is because 0 <= 0 1412 // yields true, and thus densifies the result. 1413 return {std::nullopt, false}; 1414 } 1415 1416 auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr, 1417 ci.getPredicateAttr()); 1418 return {e, hasSpDep}; 1419 } 1420 if (auto cf = dyn_cast<arith::CmpFOp>(def)) { 1421 if (cf.getPredicate() == arith::CmpFPredicate::OEQ && 1422 cf.getPredicate() == arith::CmpFPredicate::OGE && 1423 cf.getPredicate() == arith::CmpFPredicate::OLE && 1424 cf.getPredicate() == arith::CmpFPredicate::ONE && 1425 cf.getPredicate() == arith::CmpFPredicate::UEQ && 1426 cf.getPredicate() == arith::CmpFPredicate::UGE && 1427 cf.getPredicate() == arith::CmpFPredicate::ULE && 1428 cf.getPredicate() == arith::CmpFPredicate::ORD && 1429 cf.getPredicate() == arith::CmpFPredicate::UNO) { 1430 // We can not sparsify comparison with equal, this is because 0 <= 0 1431 // yields true, and thus densifies the result. 
1432 return {std::nullopt, false}; 1433 } 1434 auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr, 1435 cf.getPredicateAttr()); 1436 return {e, hasSpDep}; 1437 } 1438 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) { 1439 if (isAdmissibleBranch(binop, binop.getOverlapRegion()) && 1440 (binop.getLeftIdentity() || 1441 isAdmissibleBranch(binop, binop.getLeftRegion())) && 1442 (binop.getRightIdentity() || 1443 isAdmissibleBranch(binop, binop.getRightRegion()))) 1444 return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep}; 1445 } 1446 } 1447 } 1448 // Construct ternary operations if subexpressions can be built. 1449 if (def->getNumOperands() == 3) { 1450 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0)); 1451 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1)); 1452 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2)); 1453 bool hasSpDep = xDepSp || yDepSp || zDepSp; 1454 if (x.has_value() && y.has_value() && z.has_value()) { 1455 const ExprId e0 = *x; 1456 const ExprId e1 = *y; 1457 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) { 1458 if (isAdmissibleBranch(redop, redop.getRegion())) 1459 return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep}; 1460 } 1461 } 1462 } 1463 1464 // If we reach here, we are dealing with an operation that is not currently 1465 // sparsifiable. We can still generate code for it if all its operands only 1466 // have dense dependencies (i.e., all the values are loaded from dense 1467 // tensors). 1468 if (def->getNumResults() != 1) // only handle single result operation. 
1469 return {std::nullopt, false}; 1470 1471 SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp; 1472 // Builds all the sub-expressions 1473 for (Value operand : def->getOperands()) 1474 subExp.push_back(buildTensorExp(op, operand)); 1475 1476 if (llvm::all_of(subExp, 1477 [](auto e) { return e.first.has_value() && !e.second; })) { 1478 // All the subexpressions can be built and has *no* sparse dependencies. 1479 if (subExp.size() == 2) { 1480 auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first, 1481 *subExp[1].first, def); 1482 return {e, false}; 1483 } 1484 if (subExp.size() == 1) { 1485 auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first, 1486 detail::kInvalidId, def); 1487 return {e, false}; 1488 } 1489 } 1490 // Cannot build. 1491 return {std::nullopt, false}; 1492 } 1493 1494 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, 1495 ValueRange vals) { 1496 // Make a clone of overlap region. 1497 Region tmpRegion; 1498 IRMapping mapper; 1499 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper); 1500 Block &clonedBlock = tmpRegion.front(); 1501 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator()); 1502 // Merge cloned block and return yield value. 1503 Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1504 rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals); 1505 Value val = clonedYield.getResult(); 1506 rewriter.eraseOp(clonedYield); 1507 rewriter.eraseOp(placeholder); 1508 return val; 1509 } 1510 1511 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, 1512 Operation *op, Value v0) { 1513 if (!v0) 1514 // Empty input value must be propagated. 1515 return Value(); 1516 UnaryOp unop = cast<UnaryOp>(op); 1517 Region &presentRegion = unop.getPresentRegion(); 1518 if (presentRegion.empty()) 1519 // Uninitialized Value() will be interpreted as missing data in the 1520 // output. 
1521 return Value(); 1522 return insertYieldOp(rewriter, loc, presentRegion, {v0}); 1523 } 1524 1525 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, 1526 Operation *op, Value v0, Value v1) { 1527 if (!v0 || !v1) 1528 // Empty input values must be propagated. 1529 return Value(); 1530 BinaryOp binop = cast<BinaryOp>(op); 1531 Region &overlapRegion = binop.getOverlapRegion(); 1532 if (overlapRegion.empty()) 1533 // Uninitialized Value() will be interpreted as missing data in the 1534 // output. 1535 return Value(); 1536 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); 1537 } 1538 1539 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, 1540 Value v1) const { 1541 const auto &expr = exp(e); 1542 switch (expr.kind) { 1543 // Leaf. 1544 case TensorExp::Kind::kTensor: 1545 case TensorExp::Kind::kInvariant: 1546 case TensorExp::Kind::kLoopVar: 1547 case TensorExp::Kind::kSynZero: 1548 llvm_unreachable("unexpected non-op"); 1549 // Unary operations. 
1550 case TensorExp::Kind::kAbsF: 1551 return rewriter.create<math::AbsFOp>(loc, v0); 1552 case TensorExp::Kind::kAbsC: { 1553 auto type = cast<ComplexType>(v0.getType()); 1554 auto eltType = cast<FloatType>(type.getElementType()); 1555 return rewriter.create<complex::AbsOp>(loc, eltType, v0); 1556 } 1557 case TensorExp::Kind::kAbsI: 1558 return rewriter.create<math::AbsIOp>(loc, v0); 1559 case TensorExp::Kind::kCeilF: 1560 return rewriter.create<math::CeilOp>(loc, v0); 1561 case TensorExp::Kind::kFloorF: 1562 return rewriter.create<math::FloorOp>(loc, v0); 1563 case TensorExp::Kind::kSqrtF: 1564 return rewriter.create<math::SqrtOp>(loc, v0); 1565 case TensorExp::Kind::kSqrtC: 1566 return rewriter.create<complex::SqrtOp>(loc, v0); 1567 case TensorExp::Kind::kExpm1F: 1568 return rewriter.create<math::ExpM1Op>(loc, v0); 1569 case TensorExp::Kind::kExpm1C: 1570 return rewriter.create<complex::Expm1Op>(loc, v0); 1571 case TensorExp::Kind::kLog1pF: 1572 return rewriter.create<math::Log1pOp>(loc, v0); 1573 case TensorExp::Kind::kLog1pC: 1574 return rewriter.create<complex::Log1pOp>(loc, v0); 1575 case TensorExp::Kind::kSinF: 1576 return rewriter.create<math::SinOp>(loc, v0); 1577 case TensorExp::Kind::kSinC: 1578 return rewriter.create<complex::SinOp>(loc, v0); 1579 case TensorExp::Kind::kTanhF: 1580 return rewriter.create<math::TanhOp>(loc, v0); 1581 case TensorExp::Kind::kTanhC: 1582 return rewriter.create<complex::TanhOp>(loc, v0); 1583 case TensorExp::Kind::kNegF: 1584 return rewriter.create<arith::NegFOp>(loc, v0); 1585 case TensorExp::Kind::kNegC: 1586 return rewriter.create<complex::NegOp>(loc, v0); 1587 case TensorExp::Kind::kNegI: // no negi in std 1588 return rewriter.create<arith::SubIOp>( 1589 loc, 1590 rewriter.create<arith::ConstantOp>(loc, v0.getType(), 1591 rewriter.getZeroAttr(v0.getType())), 1592 v0); 1593 case TensorExp::Kind::kTruncF: 1594 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0); 1595 case TensorExp::Kind::kExtF: 1596 return 
rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0); 1597 case TensorExp::Kind::kCastFS: 1598 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0); 1599 case TensorExp::Kind::kCastFU: 1600 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0); 1601 case TensorExp::Kind::kCastSF: 1602 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0); 1603 case TensorExp::Kind::kCastUF: 1604 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0); 1605 case TensorExp::Kind::kCastS: 1606 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0); 1607 case TensorExp::Kind::kCastU: 1608 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0); 1609 case TensorExp::Kind::kCastIdx: 1610 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); 1611 case TensorExp::Kind::kTruncI: 1612 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); 1613 case TensorExp::Kind::kCIm: { 1614 auto type = cast<ComplexType>(v0.getType()); 1615 auto eltType = cast<FloatType>(type.getElementType()); 1616 return rewriter.create<complex::ImOp>(loc, eltType, v0); 1617 } 1618 case TensorExp::Kind::kCRe: { 1619 auto type = cast<ComplexType>(v0.getType()); 1620 auto eltType = cast<FloatType>(type.getElementType()); 1621 return rewriter.create<complex::ReOp>(loc, eltType, v0); 1622 } 1623 case TensorExp::Kind::kBitCast: 1624 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); 1625 // Binary operations. 
1626 case TensorExp::Kind::kMulF: 1627 return rewriter.create<arith::MulFOp>(loc, v0, v1); 1628 case TensorExp::Kind::kMulC: 1629 return rewriter.create<complex::MulOp>(loc, v0, v1); 1630 case TensorExp::Kind::kMulI: 1631 return rewriter.create<arith::MulIOp>(loc, v0, v1); 1632 case TensorExp::Kind::kDivF: 1633 return rewriter.create<arith::DivFOp>(loc, v0, v1); 1634 case TensorExp::Kind::kDivC: 1635 return rewriter.create<complex::DivOp>(loc, v0, v1); 1636 case TensorExp::Kind::kDivS: 1637 return rewriter.create<arith::DivSIOp>(loc, v0, v1); 1638 case TensorExp::Kind::kDivU: 1639 return rewriter.create<arith::DivUIOp>(loc, v0, v1); 1640 case TensorExp::Kind::kAddF: 1641 return rewriter.create<arith::AddFOp>(loc, v0, v1); 1642 case TensorExp::Kind::kAddC: 1643 return rewriter.create<complex::AddOp>(loc, v0, v1); 1644 case TensorExp::Kind::kAddI: 1645 return rewriter.create<arith::AddIOp>(loc, v0, v1); 1646 case TensorExp::Kind::kSubF: 1647 return rewriter.create<arith::SubFOp>(loc, v0, v1); 1648 case TensorExp::Kind::kSubC: 1649 return rewriter.create<complex::SubOp>(loc, v0, v1); 1650 case TensorExp::Kind::kSubI: 1651 return rewriter.create<arith::SubIOp>(loc, v0, v1); 1652 case TensorExp::Kind::kAndI: 1653 return rewriter.create<arith::AndIOp>(loc, v0, v1); 1654 case TensorExp::Kind::kOrI: 1655 return rewriter.create<arith::OrIOp>(loc, v0, v1); 1656 case TensorExp::Kind::kXorI: 1657 return rewriter.create<arith::XOrIOp>(loc, v0, v1); 1658 case TensorExp::Kind::kShrS: 1659 return rewriter.create<arith::ShRSIOp>(loc, v0, v1); 1660 case TensorExp::Kind::kShrU: 1661 return rewriter.create<arith::ShRUIOp>(loc, v0, v1); 1662 case TensorExp::Kind::kShlI: 1663 return rewriter.create<arith::ShLIOp>(loc, v0, v1); 1664 case TensorExp::Kind::kCmpI: { 1665 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr); 1666 return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1); 1667 } 1668 case TensorExp::Kind::kCmpF: { 1669 auto predicate = 
llvm::cast<arith::CmpFPredicateAttr>(expr.attr); 1670 return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1); 1671 } 1672 case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. 1673 return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(), 1674 {v0}); 1675 case TensorExp::Kind::kUnary: 1676 return buildUnaryPresent(rewriter, loc, expr.op, v0); 1677 case TensorExp::Kind::kSelect: 1678 return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(), 1679 {v0}); 1680 case TensorExp::Kind::kBinary: 1681 return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1); 1682 case TensorExp::Kind::kReduce: { 1683 ReduceOp redOp = cast<ReduceOp>(expr.op); 1684 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1}); 1685 } 1686 case TensorExp::Kind::kDenseOp: { 1687 Operation *actualOp = expr.op; 1688 IRMapping mapping; 1689 mapping.map(actualOp->getOperand(0), v0); 1690 if (actualOp->getNumOperands() == 2) 1691 mapping.map(actualOp->getOperand(1), v1); 1692 return rewriter.clone(*actualOp, mapping)->getResult(0); 1693 } 1694 } 1695 llvm_unreachable("unexpected expression kind in build"); 1696 } 1697 1698 } // namespace sparse_tensor 1699 } // namespace mlir 1700