//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
               unsigned numFilterLoops, unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
      numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToDependencies(
          numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
                        numTensors, std::nullopt)),
      levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
                                           maxLvlRank, std::vector<LoopId>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}

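// For example (illustrative only, not tied to a particular kernel): a simple
// element-wise kernel C(i,j) = A(i,j) + B(i,j) has three input/output tensors
// and two native loops, so it would be merged with
// Merger(/*numInputOutputTensors=*/3, /*numNativeLoops=*/2,
// /*numFilterLoops=*/0, /*maxLvlRank=*/2); tensor id 2 is then the output and
// id 3 the extra synthetic tensor reserved by the constructor.
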
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
                           Operation *op) {
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId e = addExp(kind, point0.exp, point1.exp, op);
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                         Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(kind, p0, p1, op));
  return sNew;
}

LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                         Operation *op) {
  const LatSetId sNew = conjSet(kind, s0, s1, op);
  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}

LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                          Operation *orig, bool includeLeft,
                          TensorExp::Kind ltrans, Operation *opleft,
                          bool includeRight, TensorExp::Kind rtrans,
                          Operation *opright) {
  const LatSetId sNew = conjSet(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op) {
  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}

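// Note (added for exposition): the "simple" conditions computed below are the
// bits that still need an explicit test inside the generated loop. All sparse
// (compressed/singleton) bits are kept; for the last, singleton point of a
// lattice with at least one sparse operand every dense-like bit is dropped,
// and otherwise at most one dense-like bit is retained.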
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not an undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits in reverse order
  // to always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto dlt = getDimLevelType(b);
      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    return false;
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto dlt = getDimLevelType(b);
    if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

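// Note: in the two queries above, a bit counts as "sparse" when its level
// type is compressed or singleton, or when it denotes a sparse level that is
// accessed through a non-trivial affine index expression.
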
#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto dlt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

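// Roughly speaking (added as an illustrative summary): buildLattices maps a
// tensor expression to a lattice set for the given loop index. Conjunctive
// operations (e.g. x * y) intersect the iteration spaces of their operands,
// disjunctive operations (e.g. x + y) take the union, and the custom
// semi-ring operations follow the tables in the per-case comments below.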
LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(kind, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(TensorExp::Kind::kBinary, child0, child1, binop,
                      includeLeft, TensorExp::Kind::kBinaryBranch, leftYield,
                      includeRight, TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

Type Merger::inferType(ExprId e, Value src) const {
  // Obtain the destination type from the cast node.
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

/// Ensures that the sparse compiler can generate code for the expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (v.isa<BlockArgument>())
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that the sparse compiler can generate code for the branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      if (!op.isScalar(&t))
        return addTensorExp(tid);
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addInvariantExp(v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return addInvariantExp(v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addLoopVarExp(makeLoopId(indexOp.getDim()));
  }
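  // Note: the cases below rebuild the tensor expression bottom-up; whenever
  // an operand cannot be expressed, std::nullopt propagates upward and the
  // whole expression is rejected (see the final "Cannot build" return below).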
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return addExp(TensorExp::Kind::kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(TensorExp::Kind::kAbsC, e);
      if (isa<math::AbsIOp>(def))
        return addExp(TensorExp::Kind::kAbsI, e);
      if (isa<math::CeilOp>(def))
        return addExp(TensorExp::Kind::kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(TensorExp::Kind::kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(TensorExp::Kind::kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(TensorExp::Kind::kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(TensorExp::Kind::kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(TensorExp::Kind::kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(TensorExp::Kind::kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(TensorExp::Kind::kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(TensorExp::Kind::kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(TensorExp::Kind::kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(TensorExp::Kind::kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(TensorExp::Kind::kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(TensorExp::Kind::kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(TensorExp::Kind::kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(TensorExp::Kind::kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(TensorExp::Kind::kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(TensorExp::Kind::kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(TensorExp::Kind::kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(TensorExp::Kind::kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(TensorExp::Kind::kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(TensorExp::Kind::kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(TensorExp::Kind::kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(TensorExp::Kind::kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(TensorExp::Kind::kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(TensorExp::Kind::kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(TensorExp::Kind::kBitCast, e, v);
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    const auto y = buildTensorExp(op, def->getOperand(1));
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return addExp(TensorExp::Kind::kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(TensorExp::Kind::kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(TensorExp::Kind::kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(TensorExp::Kind::kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(TensorExp::Kind::kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(TensorExp::Kind::kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(TensorExp::Kind::kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(TensorExp::Kind::kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(TensorExp::Kind::kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(TensorExp::Kind::kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(TensorExp::Kind::kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(TensorExp::Kind::kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(TensorExp::Kind::kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(TensorExp::Kind::kShlI, e0, e1);
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return addExp(TensorExp::Kind::kBinary, e0, e1, def);
      }
    }
  }
  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto x = buildTensorExp(op, def->getOperand(0));
    const auto y = buildTensorExp(op, def->getOperand(1));
    const auto z = buildTensorExp(op, def->getOperand(2));
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return addExp(TensorExp::Kind::kReduce, e0, e1, def);
      }
    }
  }
  // Cannot build.
  return std::nullopt;
}

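// Note (added for clarity): the helper below clones a semi-ring region and
// inlines it at the current insertion point; the temporary constant serves
// only as an anchor for inlineBlockBefore and is erased right afterwards,
// leaving just the region body and its yielded value.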
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir