//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
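  // Each leaf stores exactly one payload (a tensor id, a loop id, or an
  // invariant value); the remaining constructor arguments must be empty.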
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
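  // Both children must be valid expression ids; only the custom kBinary and
  // kReduce kinds below also carry an operation pointer.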
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
               unsigned numFilterLoops, unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
      numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToDependencies(
          numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
                        numTensors, std::nullopt)),
      levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
                                           maxLvlRank, std::vector<LoopId>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
                           Operation *op) {
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId e = addExp(kind, point0.exp, point1.exp, op);
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                         Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(kind, p0, p1, op));
  return sNew;
}

LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                         Operation *op) {
  const LatSetId sNew = conjSet(kind, s0, s1, op);
  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}

LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                          Operation *orig, bool includeLeft,
                          TensorExp::Kind ltrans, Operation *opleft,
                          bool includeRight, TensorExp::Kind rtrans,
                          Operation *opright) {
  const LatSetId sNew = conjSet(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op) {
  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}

BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not of undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits in reverse to
  // always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
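    // Keep every sparse bit; of the remaining (dense or undefined) bits at
    // most one survives the scan, and none at all when this point is a
    // singleton that already carries a sparse bit.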
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto dlt = getLvlType(b);
      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
          !isCompressedWithHiDLT(dlt)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
    return false;
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto dlt = getLvlType(b);
    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
        isCompressedWithHiDLT(dlt))
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e1);
    llvm::dbgs() << ")";
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto dlt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(kind, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(TensorExp::Kind::kBinary, child0, child1, binop,
                      includeLeft, TensorExp::Kind::kBinaryBranch, leftYield,
                      includeRight, TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
             cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

Type Merger::inferType(ExprId e, Value src) const {
  // Obtain the destination type from the cast node.
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = dyn_cast<VectorType>(src.getType()))
    return VectorType::get(vtp.getNumElements(), dtp,
                           vtp.getNumScalableDims());
  return dtp;
}

/// Ensures that the sparse compiler can generate code for the expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that the sparse compiler can generate code for the branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      if (!op.isScalar(&t))
        return addTensorExp(tid);
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addInvariantExp(v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return addInvariantExp(v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addLoopVarExp(makeLoopId(indexOp.getDim()));
  }
  // Construct unary operations if subexpression can be built.
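  // Note that cast-like operations also record their result value, so that
  // inferType() can later recover the destination (and vector) type.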
  if (def->getNumOperands() == 1) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return addExp(TensorExp::Kind::kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(TensorExp::Kind::kAbsC, e);
      if (isa<math::AbsIOp>(def))
        return addExp(TensorExp::Kind::kAbsI, e);
      if (isa<math::CeilOp>(def))
        return addExp(TensorExp::Kind::kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(TensorExp::Kind::kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(TensorExp::Kind::kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(TensorExp::Kind::kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(TensorExp::Kind::kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(TensorExp::Kind::kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(TensorExp::Kind::kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(TensorExp::Kind::kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(TensorExp::Kind::kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(TensorExp::Kind::kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(TensorExp::Kind::kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(TensorExp::Kind::kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(TensorExp::Kind::kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(TensorExp::Kind::kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(TensorExp::Kind::kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(TensorExp::Kind::kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(TensorExp::Kind::kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(TensorExp::Kind::kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(TensorExp::Kind::kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(TensorExp::Kind::kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(TensorExp::Kind::kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(TensorExp::Kind::kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(TensorExp::Kind::kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(TensorExp::Kind::kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(TensorExp::Kind::kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(TensorExp::Kind::kBitCast, e, v);
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
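  // A semi-ring BinaryOp is only accepted when its overlap region is an
  // admissible branch and each side either declares the identity or has an
  // admissible branch of its own.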
  if (def->getNumOperands() == 2) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    const auto y = buildTensorExp(op, def->getOperand(1));
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return addExp(TensorExp::Kind::kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(TensorExp::Kind::kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(TensorExp::Kind::kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(TensorExp::Kind::kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(TensorExp::Kind::kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(TensorExp::Kind::kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(TensorExp::Kind::kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(TensorExp::Kind::kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(TensorExp::Kind::kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(TensorExp::Kind::kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(TensorExp::Kind::kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(TensorExp::Kind::kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShlI, e0, e1);
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return addExp(TensorExp::Kind::kBinary, e0, e1, def);
      }
    }
  }
  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    const auto y = buildTensorExp(op, def->getOperand(1));
    const auto z = buildTensorExp(op, def->getOperand(2));
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return addExp(TensorExp::Kind::kReduce, e0, e1, def);
      }
    }
  }
  // Cannot build.
  return std::nullopt;
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of the given region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge the cloned block and return the yield value.
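  // A temporary constant acts as an insertion anchor: the cloned block is
  // inlined before it (with `vals` substituted for the block arguments),
  // after which both the cloned yield and the anchor are erased again.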
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
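    // For a branch, the stored op is the branch's yield terminator, so its
    // parent region is the region to inline here.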
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir