1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 #include "mlir/Dialect/Arith/IR/Arith.h" 11 #include "mlir/Dialect/Complex/IR/Complex.h" 12 #include "mlir/Dialect/Math/IR/Math.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 15 #include "mlir/IR/Operation.h" 16 #include "llvm/Support/Debug.h" 17 #include <optional> 18 19 namespace mlir { 20 namespace sparse_tensor { 21 22 enum class ExpArity { 23 kNullary, 24 kUnary, 25 kBinary, 26 }; 27 28 static ExpArity getExpArity(TensorExp::Kind k) { 29 switch (k) { 30 // Leaf. 31 case TensorExp::Kind::kTensor: 32 case TensorExp::Kind::kInvariant: 33 case TensorExp::Kind::kLoopVar: 34 case TensorExp::Kind::kSynZero: 35 return ExpArity::kNullary; 36 case TensorExp::Kind::kAbsF: 37 case TensorExp::Kind::kAbsC: 38 case TensorExp::Kind::kAbsI: 39 case TensorExp::Kind::kCeilF: 40 case TensorExp::Kind::kFloorF: 41 case TensorExp::Kind::kSqrtF: 42 case TensorExp::Kind::kSqrtC: 43 case TensorExp::Kind::kExpm1F: 44 case TensorExp::Kind::kExpm1C: 45 case TensorExp::Kind::kLog1pF: 46 case TensorExp::Kind::kLog1pC: 47 case TensorExp::Kind::kSinF: 48 case TensorExp::Kind::kSinC: 49 case TensorExp::Kind::kTanhF: 50 case TensorExp::Kind::kTanhC: 51 case TensorExp::Kind::kTruncF: 52 case TensorExp::Kind::kExtF: 53 case TensorExp::Kind::kCastFS: 54 case TensorExp::Kind::kCastFU: 55 case TensorExp::Kind::kCastSF: 56 case TensorExp::Kind::kCastUF: 57 case TensorExp::Kind::kCastS: 58 case TensorExp::Kind::kCastU: 59 case TensorExp::Kind::kCastIdx: 60 case TensorExp::Kind::kTruncI: 61 case TensorExp::Kind::kCIm: 62 case TensorExp::Kind::kCRe: 63 case TensorExp::Kind::kBitCast: 64 case TensorExp::Kind::kBinaryBranch: 65 case TensorExp::Kind::kUnary: 66 case TensorExp::Kind::kSelect: 67 case TensorExp::Kind::kNegF: 68 case TensorExp::Kind::kNegC: 69 case TensorExp::Kind::kNegI: 70 return ExpArity::kUnary; 71 // Binary operations. 72 case TensorExp::Kind::kDivF: 73 case TensorExp::Kind::kDivC: 74 case TensorExp::Kind::kDivS: 75 case TensorExp::Kind::kDivU: 76 case TensorExp::Kind::kShrS: 77 case TensorExp::Kind::kShrU: 78 case TensorExp::Kind::kShlI: 79 case TensorExp::Kind::kMulF: 80 case TensorExp::Kind::kMulC: 81 case TensorExp::Kind::kMulI: 82 case TensorExp::Kind::kAndI: 83 case TensorExp::Kind::kAddF: 84 case TensorExp::Kind::kAddC: 85 case TensorExp::Kind::kAddI: 86 case TensorExp::Kind::kOrI: 87 case TensorExp::Kind::kXorI: 88 case TensorExp::Kind::kBinary: 89 case TensorExp::Kind::kReduce: 90 case TensorExp::Kind::kSubF: 91 case TensorExp::Kind::kSubC: 92 case TensorExp::Kind::kSubI: 93 case TensorExp::Kind::kCmpF: 94 case TensorExp::Kind::kCmpI: 95 return ExpArity::kBinary; 96 } 97 llvm_unreachable("unexpected kind"); 98 } 99 100 //===----------------------------------------------------------------------===// 101 // Constructors. 102 //===----------------------------------------------------------------------===// 103 104 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, 105 Operation *o, Attribute a) 106 : kind(k), val(v), op(o) { 107 switch (kind) { 108 // Leaf. 
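  // (Each leaf kind uses at most one operand: the numeric operand `x` (a
  // tensor id or a loop id) or the invariant value `v`; the asserts below
  // verify that all other operands are left empty.)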
109 case TensorExp::Kind::kTensor: 110 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); 111 tensor = x; 112 return; 113 case TensorExp::Kind::kSynZero: 114 assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o); 115 return; 116 case TensorExp::Kind::kInvariant: 117 assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); 118 return; 119 case TensorExp::Kind::kLoopVar: 120 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); 121 loop = x; 122 return; 123 // Unary operations. 124 case TensorExp::Kind::kAbsF: 125 case TensorExp::Kind::kAbsC: 126 case TensorExp::Kind::kAbsI: 127 case TensorExp::Kind::kCeilF: 128 case TensorExp::Kind::kFloorF: 129 case TensorExp::Kind::kSqrtF: 130 case TensorExp::Kind::kSqrtC: 131 case TensorExp::Kind::kExpm1F: 132 case TensorExp::Kind::kExpm1C: 133 case TensorExp::Kind::kLog1pF: 134 case TensorExp::Kind::kLog1pC: 135 case TensorExp::Kind::kSinF: 136 case TensorExp::Kind::kSinC: 137 case TensorExp::Kind::kTanhF: 138 case TensorExp::Kind::kTanhC: 139 case TensorExp::Kind::kNegF: 140 case TensorExp::Kind::kNegC: 141 case TensorExp::Kind::kNegI: 142 case TensorExp::Kind::kCIm: 143 case TensorExp::Kind::kCRe: 144 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); 145 children.e0 = x; 146 children.e1 = y; 147 return; 148 case TensorExp::Kind::kTruncF: 149 case TensorExp::Kind::kExtF: 150 case TensorExp::Kind::kCastFS: 151 case TensorExp::Kind::kCastFU: 152 case TensorExp::Kind::kCastSF: 153 case TensorExp::Kind::kCastUF: 154 case TensorExp::Kind::kCastS: 155 case TensorExp::Kind::kCastU: 156 case TensorExp::Kind::kCastIdx: 157 case TensorExp::Kind::kTruncI: 158 case TensorExp::Kind::kBitCast: 159 assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); 160 children.e0 = x; 161 children.e1 = y; 162 return; 163 case TensorExp::Kind::kBinaryBranch: 164 case TensorExp::Kind::kSelect: 165 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o); 166 children.e0 = x; 167 children.e1 = y; 168 return; 169 case TensorExp::Kind::kUnary: 170 // No assertion on y can be made, as the branching paths involve both 171 // a unary (`mapSet`) and binary (`disjSet`) pathway. 172 assert(x != detail::kInvalidId && !v && o); 173 children.e0 = x; 174 children.e1 = y; 175 return; 176 // Binary operations. 
177 case TensorExp::Kind::kMulF: 178 case TensorExp::Kind::kMulC: 179 case TensorExp::Kind::kMulI: 180 case TensorExp::Kind::kDivF: 181 case TensorExp::Kind::kDivC: 182 case TensorExp::Kind::kDivS: 183 case TensorExp::Kind::kDivU: 184 case TensorExp::Kind::kAddF: 185 case TensorExp::Kind::kAddC: 186 case TensorExp::Kind::kAddI: 187 case TensorExp::Kind::kSubF: 188 case TensorExp::Kind::kSubC: 189 case TensorExp::Kind::kSubI: 190 case TensorExp::Kind::kAndI: 191 case TensorExp::Kind::kOrI: 192 case TensorExp::Kind::kXorI: 193 case TensorExp::Kind::kShrS: 194 case TensorExp::Kind::kShrU: 195 case TensorExp::Kind::kShlI: 196 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); 197 children.e0 = x; 198 children.e1 = y; 199 return; 200 case TensorExp::Kind::kCmpF: 201 case TensorExp::Kind::kCmpI: 202 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); 203 attr = a; 204 children.e0 = x; 205 children.e1 = y; 206 return; 207 case TensorExp::Kind::kBinary: 208 case TensorExp::Kind::kReduce: 209 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); 210 children.e0 = x; 211 children.e1 = y; 212 return; 213 } 214 llvm_unreachable("unexpected kind"); 215 } 216 217 Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, 218 unsigned numFilterLoops, unsigned maxLvlRank) 219 : outTensor(numInputOutputTensors - 1), 220 syntheticTensor(numInputOutputTensors), 221 numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops), 222 numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false), 223 lvlTypes(numTensors, 224 std::vector<DimLevelType>(numLoops, DimLevelType::Undef)), 225 loopToLvl(numTensors, 226 std::vector<std::optional<Level>>(numLoops, std::nullopt)), 227 lvlToLoop(numTensors, 228 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)), 229 loopToDependencies( 230 numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>( 231 numTensors, std::nullopt)), 232 levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>( 233 maxLvlRank, std::vector<LoopId>())), 234 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} 235 236 //===----------------------------------------------------------------------===// 237 // Lattice methods. 
238 //===----------------------------------------------------------------------===// 239 240 ExprId Merger::addTensorExp(TensorId t) { 241 assert(isValidTensorId(t)); 242 const ExprId eNew(tensorExps.size()); 243 tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId, 244 Value(), nullptr, nullptr); 245 return eNew; 246 } 247 248 ExprId Merger::addLoopVarExp(LoopId i) { 249 assert(isValidLoopId(i)); 250 const ExprId eNew(tensorExps.size()); 251 tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId, 252 Value(), nullptr, nullptr); 253 return eNew; 254 } 255 256 ExprId Merger::addInvariantExp(Value v) { 257 const ExprId eNew(tensorExps.size()); 258 tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, 259 detail::kInvalidId, v, nullptr, nullptr); 260 return eNew; 261 } 262 263 ExprId Merger::addSynZeroExp() { 264 const ExprId eNew(tensorExps.size()); 265 tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId, 266 detail::kInvalidId, Value(), nullptr, nullptr); 267 return eNew; 268 } 269 270 ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op, 271 Attribute attr) { 272 assert(k > TensorExp::Kind::kLoopVar); 273 const ExprId eNew(tensorExps.size()); 274 tensorExps.emplace_back(k, e0, e1, Value(), op, attr); 275 return eNew; 276 } 277 278 ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op, 279 Attribute attr) { 280 assert(k > TensorExp::Kind::kLoopVar); 281 const ExprId eNew(tensorExps.size()); 282 tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr); 283 return eNew; 284 } 285 286 LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { 287 const LatPointId pNew(latPoints.size()); 288 const unsigned size = numLoops * numTensors; 289 const TensorLoopId b = makeTensorLoopId(t, i); 290 latPoints.emplace_back(size, e); 291 latPoints[pNew].bits.set(b); 292 return pNew; 293 } 294 295 LatPointId Merger::addLat(const BitVector &bits, ExprId e) { 296 assert(bits.size() == numLoops * numTensors); 297 const LatPointId pNew(latPoints.size()); 298 latPoints.emplace_back(bits, e); 299 return pNew; 300 } 301 302 LatSetId Merger::addSet() { 303 const LatSetId sNew(latSets.size()); 304 latSets.emplace_back(); 305 return sNew; 306 } 307 308 LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1, 309 Operation *op) { 310 TensorExp::Kind kind = exp(e).kind; 311 Attribute attr = exp(e).attr; 312 const LatPointId pNew(latPoints.size()); 313 const auto &point0 = lat(p0); 314 const auto &point1 = lat(p1); 315 BitVector bits(point0.bits); 316 bits |= point1.bits; 317 const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr); 318 latPoints.emplace_back(bits, ne); 319 return pNew; 320 } 321 322 LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { 323 const LatSetId sNew = addSet(); 324 auto &setNew = latSets[sNew]; 325 for (const LatPointId p0 : set(s0)) 326 for (const LatPointId p1 : set(s1)) 327 setNew.push_back(conjLat(e, p0, p1, op)); 328 return sNew; 329 } 330 331 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { 332 const LatSetId sNew = conjSet(e, s0, s1, op); 333 TensorExp::Kind kind = exp(e).kind; 334 335 // Followed by all in s0. 336 latSets[sNew].append(latSets[s0]); 337 // Map binary 0-y to unary -y. 
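  // For example, for a subtraction x - y the resulting disjunction contains
  // the points for x - y (conjunction above), the points for x alone (s0),
  // and the points for -y alone (s1 after the mapping below).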
338 // TODO: move this if-else logic into buildLattices 339 if (kind == TensorExp::Kind::kSubF) 340 s1 = mapSet(TensorExp::Kind::kNegF, s1); 341 else if (kind == TensorExp::Kind::kSubC) 342 s1 = mapSet(TensorExp::Kind::kNegC, s1); 343 else if (kind == TensorExp::Kind::kSubI) 344 s1 = mapSet(TensorExp::Kind::kNegI, s1); 345 // Followed by all in s1. 346 latSets[sNew].append(latSets[s1]); 347 return sNew; 348 } 349 350 LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) { 351 assert(exp(e).kind == TensorExp::Kind::kCmpI || 352 exp(e).kind == TensorExp::Kind::kCmpF); 353 const LatSetId sNew = conjSet(e, s0, s1, nullptr); 354 355 ExprId e0 = exp(e).children.e0; 356 ExprId e1 = exp(e).children.e1; 357 if (exp(e0).kind == TensorExp::Kind::kSynZero || 358 exp(e1).kind == TensorExp::Kind::kSynZero) { 359 // lhs and rhs can't be synthetic zero at the same time. 360 assert(exp(e0).kind != exp(e1).kind); 361 // If one of the operands has already been assigned to zero (the 362 // element is absent in the corresponding operand), then we do not 363 // need to build disjunctive set for it. 364 return sNew; 365 } 366 367 auto lhsSet = mapBinWithSynZeroSet(e, s0, false); 368 auto rhsSet = mapBinWithSynZeroSet(e, s1, true); 369 latSets[sNew].append(latSets[lhsSet]); 370 latSets[sNew].append(latSets[rhsSet]); 371 return sNew; 372 } 373 374 LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, 375 bool includeLeft, TensorExp::Kind ltrans, 376 Operation *opleft, bool includeRight, 377 TensorExp::Kind rtrans, Operation *opright) { 378 const LatSetId sNew = conjSet(e, s0, s1, orig); 379 // Left Region. 380 if (includeLeft) { 381 if (opleft) 382 s0 = mapSet(ltrans, s0, Value(), opleft); 383 latSets[sNew].append(latSets[s0]); 384 } 385 // Right Region. 386 if (includeRight) { 387 if (opright) 388 s1 = mapSet(rtrans, s1, Value(), opright); 389 latSets[sNew].append(latSets[s1]); 390 } 391 return sNew; 392 } 393 394 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, 395 Operation *op) { 396 assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect); 397 const LatSetId sNew = addSet(); 398 auto &setNew = latSets[sNew]; 399 for (const LatPointId p : set(s0)) { 400 const auto &point = latPoints[p]; 401 setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op))); 402 } 403 return sNew; 404 } 405 406 LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) { 407 TensorExp::Kind kind = exp(e).kind; 408 Attribute a = exp(e).attr; 409 assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI); 410 // Must be a binary operation. 411 const LatSetId sNew = addSet(); 412 auto &setNew = latSets[sNew]; 413 const ExprId zeroExp = addSynZeroExp(); 414 for (const LatPointId p : set(s0)) { 415 const auto &point = latPoints[p]; 416 ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a) 417 : addExp(kind, point.exp, zeroExp, nullptr, a); 418 setNew.push_back(addLat(point.bits, newExp)); 419 } 420 return sNew; 421 } 422 423 LatSetId Merger::optimizeSet(LatSetId s0) { 424 const LatSetId sNew = addSet(); 425 auto &setNew = latSets[sNew]; 426 const auto &set0 = set(s0); 427 assert(!set0.empty()); 428 const LatPointId p0 = set0[0]; 429 for (const LatPointId p1 : set0) { 430 bool add = true; 431 if (p0 != p1) { 432 // Check whether this is a straightforward copy. 433 if (expIsTensor(latPoints[p1].exp, outTensor)) 434 continue; 435 // Check whether this conjunction is already covered. 
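        // (A point is covered when its bits differ from an already kept
        // point only in dense loops; see onlyDenseDiff. Dense loops are
        // always visited, so such a point adds no new condition.)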
436       for (const LatPointId p2 : setNew) {
437         assert(!latGT(p1, p2)); // Lj => Li would be bad
438         if (onlyDenseDiff(p2, p1)) {
439           add = false;
440           break;
441         }
442       }
443       assert(!add || latGT(p0, p1));
444     }
445     if (add)
446       setNew.push_back(p1);
447   }
448   for (const LatPointId p : setNew)
449     latPoints[p].simple = simplifyCond(sNew, p);
450   return sNew;
451 }
452
453 BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
454   // First determine if this lattice point is a *singleton*, i.e.,
455   // the last point in a lattice: no other point is less than this one.
456   bool isSingleton = true;
457   for (const LatPointId p1 : set(s0)) {
458     if (p0 != p1 && latGT(p0, p1)) {
459       isSingleton = false;
460       break;
461     }
462   }
463
464   BitVector simple(latPoints[p0].bits);
465   bool reset = isSingleton && hasAnySparse(simple);
466   const TensorLoopId be = simple.size();
467   TensorLoopId offset = 0; // relative to the end
468   if (!reset)
469     // Start resetting from a dense level, so that the first bit (if kept)
470     // is not of an undefined level-type.
471     for (unsigned b = 0; b < be; b++) {
472       if (simple[b] && isDenseDLT(getLvlType(TensorLoopId{b}))) {
473         offset = be - b - 1; // relative to the end
474         break;
475       }
476     }
477
478   // Now apply the two basic rules, iterating over the bits in reverse to
479   // always keep the rightmost bit (which could possibly be a synthetic tensor).
480   for (unsigned b = be - 1 - offset, i = 0; i < be;
481        b = b == 0 ? be - 1 : b - 1, i++) {
482     // A slice on a dense level has the `locate` property as well, and can be optimized.
483     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
484       const auto dlt = getLvlType(b);
485       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
486           !isCompressedWithHiDLT(dlt)) {
487         if (reset)
488           simple.reset(b);
489         reset = true;
490       }
491     }
492   }
493   return simple;
494 }
495
496 bool Merger::latGT(LatPointId i, LatPointId j) const {
497   const BitVector &bitsi = lat(i).bits;
498   const BitVector &bitsj = lat(j).bits;
499   assert(bitsi.size() == bitsj.size());
500   if (bitsi.count() > bitsj.count()) {
501     for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
502       if (bitsj[b] && !bitsi[b])
503         return false;
504     return true;
505   }
506   return false;
507 }
508
509 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
510   BitVector tmp(latPoints[j].bits);
511   tmp ^= latPoints[i].bits;
512   return !hasAnySparse(tmp);
513 }
514
515 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
516   const auto &expr = exp(e);
517   // First we check `expIsTensor`.
518 if (expr.kind == TensorExp::Kind::kTensor) 519 return expr.tensor == t; 520 521 switch (getExpArity(expr.kind)) { 522 case ExpArity::kNullary: 523 return false; 524 case ExpArity::kUnary: { 525 const ExprId e0 = expr.children.e0; 526 return expContainsTensor(e0, t); 527 } 528 case ExpArity::kBinary: { 529 const ExprId e0 = expr.children.e0; 530 const ExprId e1 = expr.children.e1; 531 return expContainsTensor(e0, t) || expContainsTensor(e1, t); 532 } 533 } 534 llvm_unreachable("unexpected arity"); 535 } 536 537 bool Merger::hasNegateOnOut(ExprId e) const { 538 const auto &expr = exp(e); 539 switch (expr.kind) { 540 case TensorExp::Kind::kNegF: 541 case TensorExp::Kind::kNegC: 542 case TensorExp::Kind::kNegI: 543 return expContainsTensor(expr.children.e0, outTensor); 544 case TensorExp::Kind::kSubF: 545 case TensorExp::Kind::kSubC: 546 case TensorExp::Kind::kSubI: 547 return expContainsTensor(expr.children.e1, outTensor) || 548 hasNegateOnOut(expr.children.e0); 549 default: { 550 switch (getExpArity(expr.kind)) { 551 case ExpArity::kNullary: 552 return false; 553 case ExpArity::kUnary: 554 return hasNegateOnOut(expr.children.e0); 555 case ExpArity::kBinary: 556 return hasNegateOnOut(expr.children.e0) || 557 hasNegateOnOut(expr.children.e1); 558 } 559 } 560 } 561 llvm_unreachable("unexpected kind"); 562 } 563 564 bool Merger::isSingleCondition(TensorId t, ExprId e) const { 565 assert(isValidTensorId(t)); 566 const auto &expr = exp(e); 567 switch (expr.kind) { 568 // Leaf. 569 case TensorExp::Kind::kTensor: 570 return expr.tensor == t; 571 case TensorExp::Kind::kInvariant: 572 case TensorExp::Kind::kLoopVar: 573 case TensorExp::Kind::kSynZero: 574 return false; 575 // Unary operations. 576 case TensorExp::Kind::kAbsF: 577 case TensorExp::Kind::kAbsC: 578 case TensorExp::Kind::kAbsI: 579 case TensorExp::Kind::kCeilF: 580 case TensorExp::Kind::kFloorF: 581 case TensorExp::Kind::kSqrtF: 582 case TensorExp::Kind::kSqrtC: 583 case TensorExp::Kind::kExpm1F: 584 case TensorExp::Kind::kExpm1C: 585 case TensorExp::Kind::kLog1pF: 586 case TensorExp::Kind::kLog1pC: 587 case TensorExp::Kind::kSinF: 588 case TensorExp::Kind::kSinC: 589 case TensorExp::Kind::kTanhF: 590 case TensorExp::Kind::kTanhC: 591 case TensorExp::Kind::kNegF: 592 case TensorExp::Kind::kNegC: 593 case TensorExp::Kind::kNegI: 594 case TensorExp::Kind::kTruncF: 595 case TensorExp::Kind::kExtF: 596 case TensorExp::Kind::kCastFS: 597 case TensorExp::Kind::kCastFU: 598 case TensorExp::Kind::kCastSF: 599 case TensorExp::Kind::kCastUF: 600 case TensorExp::Kind::kCastS: 601 case TensorExp::Kind::kCastU: 602 case TensorExp::Kind::kCastIdx: 603 case TensorExp::Kind::kTruncI: 604 case TensorExp::Kind::kCIm: 605 case TensorExp::Kind::kCRe: 606 case TensorExp::Kind::kBitCast: 607 case TensorExp::Kind::kUnary: 608 return isSingleCondition(t, expr.children.e0); 609 case TensorExp::Kind::kBinaryBranch: 610 case TensorExp::Kind::kSelect: 611 return false; 612 // Binary operations. 
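  // For example, t(i) * c and t(i) * t(i) remain a single condition on
  // tensor t, whereas t(i) + s(i) does not, since the sum is also nonzero
  // where only one of the operands is present.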
613 case TensorExp::Kind::kDivF: // note: x / c only 614 case TensorExp::Kind::kDivC: 615 case TensorExp::Kind::kDivS: 616 case TensorExp::Kind::kDivU: 617 assert(!maybeZero(expr.children.e1)); 618 return isSingleCondition(t, expr.children.e0); 619 case TensorExp::Kind::kShrS: // note: x >> inv only 620 case TensorExp::Kind::kShrU: 621 case TensorExp::Kind::kShlI: 622 assert(isInvariant(expr.children.e1)); 623 return isSingleCondition(t, expr.children.e0); 624 case TensorExp::Kind::kMulF: 625 case TensorExp::Kind::kMulC: 626 case TensorExp::Kind::kMulI: 627 case TensorExp::Kind::kAndI: 628 case TensorExp::Kind::kReduce: 629 if (isSingleCondition(t, expr.children.e0)) 630 return isSingleCondition(t, expr.children.e1) || 631 isInvariant(expr.children.e1); 632 if (isSingleCondition(t, expr.children.e1)) 633 return isInvariant(expr.children.e0); 634 return false; 635 case TensorExp::Kind::kAddF: 636 case TensorExp::Kind::kAddC: 637 case TensorExp::Kind::kAddI: 638 return isSingleCondition(t, expr.children.e0) && 639 isSingleCondition(t, expr.children.e1); 640 case TensorExp::Kind::kSubF: 641 case TensorExp::Kind::kSubC: 642 case TensorExp::Kind::kSubI: 643 case TensorExp::Kind::kOrI: 644 case TensorExp::Kind::kXorI: 645 case TensorExp::Kind::kCmpF: 646 case TensorExp::Kind::kCmpI: 647 case TensorExp::Kind::kBinary: 648 return false; 649 } 650 llvm_unreachable("unexpected kind"); 651 } 652 653 bool Merger::hasAnySparse(const BitVector &bits) const { 654 for (TensorLoopId b : bits.set_bits()) { 655 const auto dlt = getLvlType(b); 656 if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || 657 isCompressedWithHiDLT(dlt)) 658 return true; 659 } 660 return hasSparseIdxReduction(bits); 661 } 662 663 bool Merger::hasSparseIdxReduction(const BitVector &bits) const { 664 for (TensorLoopId b : bits.set_bits()) 665 if (isSparseLvlWithNonTrivialIdxExp(b)) 666 return true; 667 return false; 668 } 669 670 #ifndef NDEBUG 671 672 //===----------------------------------------------------------------------===// 673 // Print methods (for debugging). 674 //===----------------------------------------------------------------------===// 675 676 static const char *kindToOpSymbol(TensorExp::Kind kind) { 677 switch (kind) { 678 // Leaf. 679 case TensorExp::Kind::kTensor: 680 return "tensor"; 681 case TensorExp::Kind::kInvariant: 682 return "invariant"; 683 case TensorExp::Kind::kLoopVar: 684 return "index"; 685 case TensorExp::Kind::kSynZero: 686 return "0"; 687 // Unary operations. 
688 case TensorExp::Kind::kAbsF: 689 case TensorExp::Kind::kAbsC: 690 case TensorExp::Kind::kAbsI: 691 return "abs"; 692 case TensorExp::Kind::kCeilF: 693 return "ceil"; 694 case TensorExp::Kind::kFloorF: 695 return "floor"; 696 case TensorExp::Kind::kSqrtF: 697 case TensorExp::Kind::kSqrtC: 698 return "sqrt"; 699 case TensorExp::Kind::kExpm1F: 700 case TensorExp::Kind::kExpm1C: 701 return "expm1"; 702 case TensorExp::Kind::kLog1pF: 703 case TensorExp::Kind::kLog1pC: 704 return "log1p"; 705 case TensorExp::Kind::kSinF: 706 case TensorExp::Kind::kSinC: 707 return "sin"; 708 case TensorExp::Kind::kTanhF: 709 case TensorExp::Kind::kTanhC: 710 return "tanh"; 711 case TensorExp::Kind::kNegF: 712 case TensorExp::Kind::kNegC: 713 case TensorExp::Kind::kNegI: 714 return "-"; 715 case TensorExp::Kind::kTruncF: 716 case TensorExp::Kind::kExtF: 717 case TensorExp::Kind::kCastFS: 718 case TensorExp::Kind::kCastFU: 719 case TensorExp::Kind::kCastSF: 720 case TensorExp::Kind::kCastUF: 721 case TensorExp::Kind::kCastS: 722 case TensorExp::Kind::kCastU: 723 case TensorExp::Kind::kCastIdx: 724 case TensorExp::Kind::kTruncI: 725 case TensorExp::Kind::kCIm: 726 return "complex.im"; 727 case TensorExp::Kind::kCRe: 728 return "complex.re"; 729 case TensorExp::Kind::kBitCast: 730 return "cast"; 731 case TensorExp::Kind::kBinaryBranch: 732 return "binary_branch"; 733 case TensorExp::Kind::kUnary: 734 return "unary"; 735 case TensorExp::Kind::kSelect: 736 return "select"; 737 // Binary operations. 738 case TensorExp::Kind::kMulF: 739 case TensorExp::Kind::kMulC: 740 case TensorExp::Kind::kMulI: 741 return "*"; 742 case TensorExp::Kind::kDivF: 743 case TensorExp::Kind::kDivC: 744 case TensorExp::Kind::kDivS: 745 case TensorExp::Kind::kDivU: 746 return "/"; 747 case TensorExp::Kind::kAddF: 748 case TensorExp::Kind::kAddC: 749 case TensorExp::Kind::kAddI: 750 return "+"; 751 case TensorExp::Kind::kSubF: 752 case TensorExp::Kind::kSubC: 753 case TensorExp::Kind::kSubI: 754 return "-"; 755 case TensorExp::Kind::kAndI: 756 return "&"; 757 case TensorExp::Kind::kOrI: 758 return "|"; 759 case TensorExp::Kind::kXorI: 760 return "^"; 761 case TensorExp::Kind::kShrS: 762 return "a>>"; 763 case TensorExp::Kind::kShrU: 764 return ">>"; 765 case TensorExp::Kind::kShlI: 766 return "<<"; 767 case TensorExp::Kind::kCmpF: 768 case TensorExp::Kind::kCmpI: 769 return "cmp"; 770 case TensorExp::Kind::kBinary: 771 return "binary"; 772 case TensorExp::Kind::kReduce: 773 return "reduce"; 774 } 775 llvm_unreachable("unexpected kind for symbol"); 776 } 777 778 void Merger::dumpExp(ExprId e) const { 779 const auto &expr = exp(e); 780 switch (expr.kind) { 781 // Leaf. 782 case TensorExp::Kind::kTensor: 783 if (expr.tensor == syntheticTensor) 784 llvm::dbgs() << "synthetic_"; 785 else if (expr.tensor == outTensor) 786 llvm::dbgs() << "output_"; 787 llvm::dbgs() << "tensor_" << expr.tensor; 788 break; 789 case TensorExp::Kind::kInvariant: 790 llvm::dbgs() << "invariant"; 791 break; 792 case TensorExp::Kind::kSynZero: 793 llvm::dbgs() << "0"; 794 break; 795 case TensorExp::Kind::kLoopVar: 796 llvm::dbgs() << "loopvar_" << expr.loop; 797 break; 798 // Unary operations. 
799 case TensorExp::Kind::kAbsF: 800 case TensorExp::Kind::kAbsC: 801 case TensorExp::Kind::kAbsI: 802 case TensorExp::Kind::kCeilF: 803 case TensorExp::Kind::kFloorF: 804 case TensorExp::Kind::kSqrtF: 805 case TensorExp::Kind::kSqrtC: 806 case TensorExp::Kind::kExpm1F: 807 case TensorExp::Kind::kExpm1C: 808 case TensorExp::Kind::kLog1pF: 809 case TensorExp::Kind::kLog1pC: 810 case TensorExp::Kind::kSinF: 811 case TensorExp::Kind::kSinC: 812 case TensorExp::Kind::kTanhF: 813 case TensorExp::Kind::kTanhC: 814 case TensorExp::Kind::kNegF: 815 case TensorExp::Kind::kNegC: 816 case TensorExp::Kind::kNegI: 817 case TensorExp::Kind::kTruncF: 818 case TensorExp::Kind::kExtF: 819 case TensorExp::Kind::kCastFS: 820 case TensorExp::Kind::kCastFU: 821 case TensorExp::Kind::kCastSF: 822 case TensorExp::Kind::kCastUF: 823 case TensorExp::Kind::kCastS: 824 case TensorExp::Kind::kCastU: 825 case TensorExp::Kind::kCastIdx: 826 case TensorExp::Kind::kTruncI: 827 case TensorExp::Kind::kCIm: 828 case TensorExp::Kind::kCRe: 829 case TensorExp::Kind::kBitCast: 830 case TensorExp::Kind::kBinaryBranch: 831 case TensorExp::Kind::kUnary: 832 case TensorExp::Kind::kSelect: 833 llvm::dbgs() << kindToOpSymbol(expr.kind) << " "; 834 dumpExp(expr.children.e0); 835 break; 836 // Binary operations. 837 case TensorExp::Kind::kMulF: 838 case TensorExp::Kind::kMulC: 839 case TensorExp::Kind::kMulI: 840 case TensorExp::Kind::kDivF: 841 case TensorExp::Kind::kDivC: 842 case TensorExp::Kind::kDivS: 843 case TensorExp::Kind::kDivU: 844 case TensorExp::Kind::kAddF: 845 case TensorExp::Kind::kAddC: 846 case TensorExp::Kind::kAddI: 847 case TensorExp::Kind::kSubF: 848 case TensorExp::Kind::kSubC: 849 case TensorExp::Kind::kSubI: 850 case TensorExp::Kind::kAndI: 851 case TensorExp::Kind::kOrI: 852 case TensorExp::Kind::kXorI: 853 case TensorExp::Kind::kShrS: 854 case TensorExp::Kind::kShrU: 855 case TensorExp::Kind::kShlI: 856 case TensorExp::Kind::kCmpF: 857 case TensorExp::Kind::kCmpI: 858 case TensorExp::Kind::kBinary: 859 case TensorExp::Kind::kReduce: 860 llvm::dbgs() << "("; 861 dumpExp(expr.children.e0); 862 llvm::dbgs() << " " << kindToOpSymbol(expr.kind); 863 if (expr.attr) 864 llvm::dbgs() << "{" << expr.attr << "}"; 865 llvm::dbgs() << " "; 866 dumpExp(expr.children.e1); 867 llvm::dbgs() << ")"; 868 break; 869 } 870 } 871 872 void Merger::dumpLat(LatPointId p) const { 873 const auto &point = lat(p); 874 llvm::dbgs() << "lat("; 875 dumpBits(point.bits); 876 llvm::dbgs() << " :"; 877 dumpBits(point.simple); 878 llvm::dbgs() << " : "; 879 dumpExp(point.exp); 880 llvm::dbgs() << " )\n"; 881 } 882 883 void Merger::dumpSet(LatSetId s) const { 884 const auto &ss = set(s); 885 llvm::dbgs() << "{ #" << ss.size() << "\n"; 886 for (const LatPointId p : ss) { 887 llvm::dbgs() << " "; 888 dumpLat(p); 889 } 890 llvm::dbgs() << "}\n"; 891 } 892 893 void Merger::dumpBits(const BitVector &bits) const { 894 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { 895 if (bits[b]) { 896 const TensorId t = tensor(b); 897 const LoopId i = loop(b); 898 const auto dlt = lvlTypes[t][i]; 899 if (isLvlWithNonTrivialIdxExp(b)) 900 llvm::dbgs() << " DEP_" << t << "_" << i; 901 else 902 llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); 903 } 904 } 905 } 906 907 #endif // NDEBUG 908 909 //===----------------------------------------------------------------------===// 910 // Builder methods. 
911 //===----------------------------------------------------------------------===// 912 913 LatSetId Merger::buildLattices(ExprId e, LoopId i) { 914 // NOTE: The `expr` reference will be invalidated by recursive calls 915 // (and any other method that may add new expressions); therefore, the 916 // code below must make sure to copy fields of `expr` into local variables 917 // before making any recursive calls. 918 const auto &expr = exp(e); 919 const TensorExp::Kind kind = expr.kind; 920 switch (kind) { 921 // Leaf. 922 case TensorExp::Kind::kTensor: 923 case TensorExp::Kind::kInvariant: 924 case TensorExp::Kind::kSynZero: 925 case TensorExp::Kind::kLoopVar: { 926 // Either the loop-var is really used in the tensor expression, or it is 927 // set to the undefined loop-var in that level. An invariant expression, 928 // a proper index value, and a truly dynamic sparse output tensor are set 929 // to a synthetic tensor with undefined indices only to ensure the 930 // iteration space is not skipped as a result of their contents. 931 const LatSetId s = addSet(); 932 TensorId t = syntheticTensor; 933 if (kind == TensorExp::Kind::kTensor) { 934 t = expr.tensor; 935 if (hasSparseOut && t == outTensor) 936 t = syntheticTensor; 937 } 938 latSets[s].push_back(addLat(t, i, e)); 939 return s; 940 } 941 // Unary operations. 942 case TensorExp::Kind::kAbsF: 943 case TensorExp::Kind::kAbsC: 944 case TensorExp::Kind::kAbsI: 945 case TensorExp::Kind::kCeilF: 946 case TensorExp::Kind::kFloorF: 947 case TensorExp::Kind::kSqrtF: 948 case TensorExp::Kind::kSqrtC: 949 case TensorExp::Kind::kExpm1F: 950 case TensorExp::Kind::kExpm1C: 951 case TensorExp::Kind::kLog1pF: 952 case TensorExp::Kind::kLog1pC: 953 case TensorExp::Kind::kSinF: 954 case TensorExp::Kind::kSinC: 955 case TensorExp::Kind::kTanhF: 956 case TensorExp::Kind::kTanhC: 957 case TensorExp::Kind::kNegF: 958 case TensorExp::Kind::kNegC: 959 case TensorExp::Kind::kNegI: 960 case TensorExp::Kind::kTruncF: 961 case TensorExp::Kind::kExtF: 962 case TensorExp::Kind::kCastFS: 963 case TensorExp::Kind::kCastFU: 964 case TensorExp::Kind::kCastSF: 965 case TensorExp::Kind::kCastUF: 966 case TensorExp::Kind::kCastS: 967 case TensorExp::Kind::kCastU: 968 case TensorExp::Kind::kCastIdx: 969 case TensorExp::Kind::kTruncI: 970 case TensorExp::Kind::kCIm: 971 case TensorExp::Kind::kCRe: 972 case TensorExp::Kind::kBitCast: 973 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the 974 // lattice set of the operand through the operator into a new set. 975 // 976 // -y|!y | y | 977 // --+---+---+ 978 // | 0 |-y | 979 { 980 const ExprId e0 = expr.children.e0; 981 const Value v = expr.val; 982 return mapSet(kind, buildLattices(e0, i), v); 983 } 984 case TensorExp::Kind::kBinaryBranch: 985 case TensorExp::Kind::kSelect: 986 // The left or right half of a binary operation which has already 987 // been split into separate operations for each region. 988 { 989 const ExprId e0 = expr.children.e0; 990 Operation *const op = expr.op; 991 return mapSet(kind, buildLattices(e0, i), Value(), op); 992 } 993 case TensorExp::Kind::kUnary: 994 // A custom unary operation. 995 // 996 // op y| !y | y | 997 // ----+----------+------------+ 998 // | absent() | present(y) | 999 { 1000 const ExprId e0 = expr.children.e0; 1001 UnaryOp unop = cast<UnaryOp>(expr.op); 1002 const LatSetId child0 = buildLattices(e0, i); 1003 Region &absentRegion = unop.getAbsentRegion(); 1004 if (absentRegion.empty()) { 1005 // Simple mapping over existing values. 
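          // (With no absent region, absent entries remain absent, so only
          // the present values are mapped through the unary operator.)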
1006       return mapSet(kind, child0, Value(), unop);
1007     }
1008     // Use a disjunction with `unop` on the left and the absent value as an
1009     // invariant on the right.
1010     Block &absentBlock = absentRegion.front();
1011     YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1012     const Value absentVal = absentYield.getResult();
1013     const ExprId rhs = addInvariantExp(absentVal);
1014     return disjSet(e, child0, buildLattices(rhs, i), unop);
1015   }
1016   // Binary operations.
1017   case TensorExp::Kind::kMulF:
1018   case TensorExp::Kind::kMulC:
1019   case TensorExp::Kind::kMulI:
1020   case TensorExp::Kind::kAndI:
1021     // A multiplicative operation only needs to be performed
1022     // for the conjunction of sparse iteration spaces.
1023     //
1024     //  x*y|!y | y |
1025     //  ---+---+---+
1026     //  !x | 0 | 0 |
1027     //   x | 0 |x*y|
1028     //
1029     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1030     {
1031       const ExprId e0 = expr.children.e0;
1032       const ExprId e1 = expr.children.e1;
1033       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1034     }
1035   case TensorExp::Kind::kDivF:
1036   case TensorExp::Kind::kDivC:
1037   case TensorExp::Kind::kDivS:
1038   case TensorExp::Kind::kDivU:
1039     // A division is tricky, since 0/0, 0/c, c/0 all have
1040     // specific outcomes for floating-point and integers.
1041     // Thus, we need to traverse the full iteration space.
1042     //
1043     //  x/y|!y | y |
1044     //  ---+---+---+
1045     //  !x |0/0|0/y|  FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1046     //   x |x/0|x/y|  INT: x/0=exception for any x
1047     //
1048     // TODO: for now we "fixed" this by only accepting x/c cases
1049     //       during expression building, so that the conjunction
1050     //       rule applies (viz. x/c = x*(1/c) as far as lattice
1051     //       construction is concerned).
1052     {
1053       const ExprId e0 = expr.children.e0;
1054       const ExprId e1 = expr.children.e1;
1055       assert(!maybeZero(e1));
1056       return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1057     }
1058   case TensorExp::Kind::kAddF:
1059   case TensorExp::Kind::kAddC:
1060   case TensorExp::Kind::kAddI:
1061   case TensorExp::Kind::kSubF:
1062   case TensorExp::Kind::kSubC:
1063   case TensorExp::Kind::kSubI:
1064   case TensorExp::Kind::kOrI:
1065   case TensorExp::Kind::kXorI:
1066     // An additive operation needs to be performed
1067     // for the disjunction of sparse iteration spaces.
1068     //
1069     //  x+y|!y | y |    x-y|!y | y |
1070     //  ---+---+---+    ---+---+---+
1071     //  !x | 0 | y |    !x | 0 |-y |
1072     //   x | x |x+y|     x | x |x-y|
1073     {
1074       const ExprId e0 = expr.children.e0;
1075       const ExprId e1 = expr.children.e1;
1076       return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1077     }
1078   case TensorExp::Kind::kCmpF:
1079   case TensorExp::Kind::kCmpI:
1080     // A comparison operation needs to be performed
1081     // for the disjunction of sparse iteration spaces.
1082     //
1083     //   x < y |  !y   |   y   |
1084     //  -------+-------+-------+
1085     //    !x   |   0   | 0 < y |
1086     //     x   | x < 0 | x < y |
1087     {
1088       const ExprId e0 = expr.children.e0;
1089       const ExprId e1 = expr.children.e1;
1090       return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1091     }
1092   case TensorExp::Kind::kShrS:
1093   case TensorExp::Kind::kShrU:
1094   case TensorExp::Kind::kShlI:
1095     // A shift operation by an invariant amount (viz. tensor expressions
1096     // can only occur at the left-hand-side of the operator) can be handled
1097     // with the conjunction rule.
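    // For example, x << c only needs to be computed where x is present,
    // since shifting an implicit zero by the invariant c yields zero again.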
1098 { 1099 const ExprId e0 = expr.children.e0; 1100 const ExprId e1 = expr.children.e1; 1101 assert(isInvariant(e1)); 1102 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i)); 1103 } 1104 case TensorExp::Kind::kBinary: 1105 // A custom binary operation. 1106 // 1107 // x op y| !y | y | 1108 // ------+---------+--------------+ 1109 // !x | empty | right(y) | 1110 // x | left(x) | overlap(x,y) | 1111 { 1112 const ExprId e0 = expr.children.e0; 1113 const ExprId e1 = expr.children.e1; 1114 BinaryOp binop = cast<BinaryOp>(expr.op); 1115 const LatSetId child0 = buildLattices(e0, i); 1116 const LatSetId child1 = buildLattices(e1, i); 1117 Region &leftRegion = binop.getLeftRegion(); 1118 Region &rightRegion = binop.getRightRegion(); 1119 // Left Region. 1120 Operation *leftYield = nullptr; 1121 if (!leftRegion.empty()) { 1122 Block &leftBlock = leftRegion.front(); 1123 leftYield = leftBlock.getTerminator(); 1124 } 1125 // Right Region. 1126 Operation *rightYield = nullptr; 1127 if (!rightRegion.empty()) { 1128 Block &rightBlock = rightRegion.front(); 1129 rightYield = rightBlock.getTerminator(); 1130 } 1131 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty(); 1132 bool includeRight = binop.getRightIdentity() || !rightRegion.empty(); 1133 return combiSet(e, child0, child1, binop, includeLeft, 1134 TensorExp::Kind::kBinaryBranch, leftYield, includeRight, 1135 TensorExp::Kind::kBinaryBranch, rightYield); 1136 } 1137 case TensorExp::Kind::kReduce: 1138 // A custom reduce operation. 1139 { 1140 const ExprId e0 = expr.children.e0; 1141 const ExprId e1 = expr.children.e1; 1142 Operation *const op = expr.op; 1143 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op); 1144 } 1145 } 1146 llvm_unreachable("unexpected expression kind"); 1147 } 1148 1149 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { 1150 // Build the linalg semantics backward from yield. 1151 Operation *yield = op.getRegion().front().getTerminator(); 1152 assert(isa<linalg::YieldOp>(yield)); 1153 return buildTensorExp(op, yield->getOperand(0)); 1154 } 1155 1156 /// Only returns false if we are certain this is a nonzero. 1157 bool Merger::maybeZero(ExprId e) const { 1158 const auto &expr = exp(e); 1159 if (expr.kind == TensorExp::Kind::kInvariant) { 1160 if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) { 1161 ArrayAttr arrayAttr = c.getValue(); 1162 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && 1163 cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); 1164 } 1165 if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>()) 1166 return c.value() == 0; 1167 if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>()) 1168 return c.value().isZero(); 1169 } 1170 return true; 1171 } 1172 1173 Type Merger::inferType(ExprId e, Value src) const { 1174 // Obtain the destination type from the cast node. 1175 Type dtp = exp(e).val.getType(); 1176 // Inspect source type. For vector types, apply the same 1177 // vectorization to the destination type. 1178 if (auto vtp = dyn_cast<VectorType>(src.getType())) 1179 return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims()); 1180 return dtp; 1181 } 1182 1183 /// Ensures that sparse compiler can generate code for expression. 1184 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { 1185 // Arguments are always admissible. 1186 if (isa<BlockArgument>(v)) 1187 return true; 1188 // Accept index anywhere. 
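  // (A linalg.index operation simply refers to a loop induction variable,
  // so it never blocks code generation for the branch.)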
1189 Operation *def = v.getDefiningOp(); 1190 if (isa<linalg::IndexOp>(def)) 1191 return true; 1192 // Operation defined outside branch. 1193 if (def->getBlock() != block) 1194 return def->getBlock() != op->getBlock(); // invariant? 1195 // Operation defined within branch. Anything is accepted, 1196 // as long as all subexpressions are admissible. 1197 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) 1198 if (!isAdmissibleBranchExp(op, block, def->getOperand(i))) 1199 return false; 1200 return true; 1201 } 1202 1203 /// Ensures that sparse compiler can generate code for branch. 1204 static bool isAdmissibleBranch(Operation *op, Region ®ion) { 1205 if (region.empty()) 1206 return true; 1207 // Build the semi-ring branch semantics backward from yield. 1208 Operation *yield = region.front().getTerminator(); 1209 assert(isa<YieldOp>(yield)); 1210 return isAdmissibleBranchExp(op, ®ion.front(), yield->getOperand(0)); 1211 } 1212 1213 std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) { 1214 if (auto arg = dyn_cast<BlockArgument>(v)) { 1215 const TensorId tid = makeTensorId(arg.getArgNumber()); 1216 // Any argument of the generic op that is not marked as a scalar 1217 // argument is considered a tensor, indexed by the implicit loop 1218 // bounds. This includes rank-0 tensor arguments. 1219 if (arg.getOwner()->getParentOp() == op) { 1220 OpOperand &t = op->getOpOperand(tid); 1221 if (!op.isScalar(&t)) 1222 return addTensorExp(tid); 1223 v = t.get(); // get scalar value 1224 } 1225 // Any other argument (marked as scalar argument for the generic op 1226 // or belonging to an enveloping op) is considered invariant. 1227 return addInvariantExp(v); 1228 } 1229 // Something defined outside is invariant. 1230 Operation *def = v.getDefiningOp(); 1231 if (def->getBlock() != &op.getRegion().front()) 1232 return addInvariantExp(v); 1233 // Construct index operations. 1234 if (def->getNumOperands() == 0) { 1235 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 1236 return addLoopVarExp(makeLoopId(indexOp.getDim())); 1237 } 1238 // Construct unary operations if subexpression can be built. 
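  // For example, a math.absf over an admissible operand becomes a kAbsF
  // node; cast-like operations additionally record the result value so that
  // the destination type can be recovered later (see inferType).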
1239 if (def->getNumOperands() == 1) { 1240 const auto x = buildTensorExp(op, def->getOperand(0)); 1241 if (x.has_value()) { 1242 const ExprId e = *x; 1243 if (isa<math::AbsFOp>(def)) 1244 return addExp(TensorExp::Kind::kAbsF, e); 1245 if (isa<complex::AbsOp>(def)) 1246 return addExp(TensorExp::Kind::kAbsC, e); 1247 if (isa<math::AbsIOp>(def)) 1248 return addExp(TensorExp::Kind::kAbsI, e); 1249 if (isa<math::CeilOp>(def)) 1250 return addExp(TensorExp::Kind::kCeilF, e); 1251 if (isa<math::FloorOp>(def)) 1252 return addExp(TensorExp::Kind::kFloorF, e); 1253 if (isa<math::SqrtOp>(def)) 1254 return addExp(TensorExp::Kind::kSqrtF, e); 1255 if (isa<complex::SqrtOp>(def)) 1256 return addExp(TensorExp::Kind::kSqrtC, e); 1257 if (isa<math::ExpM1Op>(def)) 1258 return addExp(TensorExp::Kind::kExpm1F, e); 1259 if (isa<complex::Expm1Op>(def)) 1260 return addExp(TensorExp::Kind::kExpm1C, e); 1261 if (isa<math::Log1pOp>(def)) 1262 return addExp(TensorExp::Kind::kLog1pF, e); 1263 if (isa<complex::Log1pOp>(def)) 1264 return addExp(TensorExp::Kind::kLog1pC, e); 1265 if (isa<math::SinOp>(def)) 1266 return addExp(TensorExp::Kind::kSinF, e); 1267 if (isa<complex::SinOp>(def)) 1268 return addExp(TensorExp::Kind::kSinC, e); 1269 if (isa<math::TanhOp>(def)) 1270 return addExp(TensorExp::Kind::kTanhF, e); 1271 if (isa<complex::TanhOp>(def)) 1272 return addExp(TensorExp::Kind::kTanhC, e); 1273 if (isa<arith::NegFOp>(def)) 1274 return addExp(TensorExp::Kind::kNegF, e); // no negi in std 1275 if (isa<complex::NegOp>(def)) 1276 return addExp(TensorExp::Kind::kNegC, e); 1277 if (isa<arith::TruncFOp>(def)) 1278 return addExp(TensorExp::Kind::kTruncF, e, v); 1279 if (isa<arith::ExtFOp>(def)) 1280 return addExp(TensorExp::Kind::kExtF, e, v); 1281 if (isa<arith::FPToSIOp>(def)) 1282 return addExp(TensorExp::Kind::kCastFS, e, v); 1283 if (isa<arith::FPToUIOp>(def)) 1284 return addExp(TensorExp::Kind::kCastFU, e, v); 1285 if (isa<arith::SIToFPOp>(def)) 1286 return addExp(TensorExp::Kind::kCastSF, e, v); 1287 if (isa<arith::UIToFPOp>(def)) 1288 return addExp(TensorExp::Kind::kCastUF, e, v); 1289 if (isa<arith::ExtSIOp>(def)) 1290 return addExp(TensorExp::Kind::kCastS, e, v); 1291 if (isa<arith::ExtUIOp>(def)) 1292 return addExp(TensorExp::Kind::kCastU, e, v); 1293 if (isa<arith::IndexCastOp>(def)) 1294 return addExp(TensorExp::Kind::kCastIdx, e, v); 1295 if (isa<arith::TruncIOp>(def)) 1296 return addExp(TensorExp::Kind::kTruncI, e, v); 1297 if (isa<complex::ImOp>(def)) 1298 return addExp(TensorExp::Kind::kCIm, e); 1299 if (isa<complex::ReOp>(def)) 1300 return addExp(TensorExp::Kind::kCRe, e); 1301 if (isa<arith::BitcastOp>(def)) 1302 return addExp(TensorExp::Kind::kBitCast, e, v); 1303 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) { 1304 if (isAdmissibleBranch(unop, unop.getPresentRegion()) && 1305 isAdmissibleBranch(unop, unop.getAbsentRegion())) 1306 return addExp(TensorExp::Kind::kUnary, e, Value(), def); 1307 } 1308 if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) { 1309 if (isAdmissibleBranch(selop, selop.getRegion())) 1310 return addExp(TensorExp::Kind::kSelect, e, Value(), def); 1311 } 1312 } 1313 } 1314 // Construct binary operations if subexpressions can be built. 1315 // See buildLattices() for an explanation of rejecting certain 1316 // division and shift operations. 
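  // (Divisions are only accepted when the divisor is known to be nonzero,
  // see maybeZero(), and shifts only with an invariant shift amount, see
  // isInvariant(), so that the conjunction rule remains applicable.)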
1317   if (def->getNumOperands() == 2) {
1318     const auto x = buildTensorExp(op, def->getOperand(0));
1319     const auto y = buildTensorExp(op, def->getOperand(1));
1320     if (x.has_value() && y.has_value()) {
1321       const ExprId e0 = *x;
1322       const ExprId e1 = *y;
1323       if (isa<arith::MulFOp>(def))
1324         return addExp(TensorExp::Kind::kMulF, e0, e1);
1325       if (isa<complex::MulOp>(def))
1326         return addExp(TensorExp::Kind::kMulC, e0, e1);
1327       if (isa<arith::MulIOp>(def))
1328         return addExp(TensorExp::Kind::kMulI, e0, e1);
1329       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1330         return addExp(TensorExp::Kind::kDivF, e0, e1);
1331       if (isa<complex::DivOp>(def) && !maybeZero(e1))
1332         return addExp(TensorExp::Kind::kDivC, e0, e1);
1333       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1334         return addExp(TensorExp::Kind::kDivS, e0, e1);
1335       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1336         return addExp(TensorExp::Kind::kDivU, e0, e1);
1337       if (isa<arith::AddFOp>(def))
1338         return addExp(TensorExp::Kind::kAddF, e0, e1);
1339       if (isa<complex::AddOp>(def))
1340         return addExp(TensorExp::Kind::kAddC, e0, e1);
1341       if (isa<arith::AddIOp>(def))
1342         return addExp(TensorExp::Kind::kAddI, e0, e1);
1343       if (isa<arith::SubFOp>(def))
1344         return addExp(TensorExp::Kind::kSubF, e0, e1);
1345       if (isa<complex::SubOp>(def))
1346         return addExp(TensorExp::Kind::kSubC, e0, e1);
1347       if (isa<arith::SubIOp>(def))
1348         return addExp(TensorExp::Kind::kSubI, e0, e1);
1349       if (isa<arith::AndIOp>(def))
1350         return addExp(TensorExp::Kind::kAndI, e0, e1);
1351       if (isa<arith::OrIOp>(def))
1352         return addExp(TensorExp::Kind::kOrI, e0, e1);
1353       if (isa<arith::XOrIOp>(def))
1354         return addExp(TensorExp::Kind::kXorI, e0, e1);
1355       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1356         return addExp(TensorExp::Kind::kShrS, e0, e1);
1357       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1358         return addExp(TensorExp::Kind::kShrU, e0, e1);
1359       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1360         return addExp(TensorExp::Kind::kShlI, e0, e1);
1361       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1362         if (ci.getPredicate() == arith::CmpIPredicate::eq ||
1363             ci.getPredicate() == arith::CmpIPredicate::sle ||
1364             ci.getPredicate() == arith::CmpIPredicate::sge ||
1365             ci.getPredicate() == arith::CmpIPredicate::ule ||
1366             ci.getPredicate() == arith::CmpIPredicate::uge) {
1367           // We cannot sparsify comparisons involving equality, because 0 <= 0
1368           // yields true and thus densifies the result.
1369           return std::nullopt;
1370         }
1371
1372         return addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1373                       ci.getPredicateAttr());
1374       }
1375       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1376         if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
1377             cf.getPredicate() == arith::CmpFPredicate::OGE ||
1378             cf.getPredicate() == arith::CmpFPredicate::OLE ||
1379             cf.getPredicate() == arith::CmpFPredicate::ONE ||
1380             cf.getPredicate() == arith::CmpFPredicate::UEQ ||
1381             cf.getPredicate() == arith::CmpFPredicate::UGE ||
1382             cf.getPredicate() == arith::CmpFPredicate::ULE ||
1383             cf.getPredicate() == arith::CmpFPredicate::ORD ||
1384             cf.getPredicate() == arith::CmpFPredicate::UNO) {
1385           // We cannot sparsify comparisons involving equality, because 0 <= 0
1386           // yields true and thus densifies the result.
1387 return std::nullopt; 1388 } 1389 return addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr, 1390 cf.getPredicateAttr()); 1391 } 1392 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) { 1393 if (isAdmissibleBranch(binop, binop.getOverlapRegion()) && 1394 (binop.getLeftIdentity() || 1395 isAdmissibleBranch(binop, binop.getLeftRegion())) && 1396 (binop.getRightIdentity() || 1397 isAdmissibleBranch(binop, binop.getRightRegion()))) 1398 return addExp(TensorExp::Kind::kBinary, e0, e1, def); 1399 } 1400 } 1401 } 1402 // Construct ternary operations if subexpressions can be built. 1403 if (def->getNumOperands() == 3) { 1404 const auto x = buildTensorExp(op, def->getOperand(0)); 1405 const auto y = buildTensorExp(op, def->getOperand(1)); 1406 const auto z = buildTensorExp(op, def->getOperand(2)); 1407 if (x.has_value() && y.has_value() && z.has_value()) { 1408 const ExprId e0 = *x; 1409 const ExprId e1 = *y; 1410 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) { 1411 if (isAdmissibleBranch(redop, redop.getRegion())) 1412 return addExp(TensorExp::Kind::kReduce, e0, e1, def); 1413 } 1414 } 1415 } 1416 // Cannot build. 1417 return std::nullopt; 1418 } 1419 1420 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, 1421 ValueRange vals) { 1422 // Make a clone of overlap region. 1423 Region tmpRegion; 1424 IRMapping mapper; 1425 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper); 1426 Block &clonedBlock = tmpRegion.front(); 1427 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator()); 1428 // Merge cloned block and return yield value. 1429 Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1430 rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals); 1431 Value val = clonedYield.getResult(); 1432 rewriter.eraseOp(clonedYield); 1433 rewriter.eraseOp(placeholder); 1434 return val; 1435 } 1436 1437 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, 1438 Operation *op, Value v0) { 1439 if (!v0) 1440 // Empty input value must be propagated. 1441 return Value(); 1442 UnaryOp unop = cast<UnaryOp>(op); 1443 Region &presentRegion = unop.getPresentRegion(); 1444 if (presentRegion.empty()) 1445 // Uninitialized Value() will be interpreted as missing data in the 1446 // output. 1447 return Value(); 1448 return insertYieldOp(rewriter, loc, presentRegion, {v0}); 1449 } 1450 1451 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, 1452 Operation *op, Value v0, Value v1) { 1453 if (!v0 || !v1) 1454 // Empty input values must be propagated. 1455 return Value(); 1456 BinaryOp binop = cast<BinaryOp>(op); 1457 Region &overlapRegion = binop.getOverlapRegion(); 1458 if (overlapRegion.empty()) 1459 // Uninitialized Value() will be interpreted as missing data in the 1460 // output. 1461 return Value(); 1462 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); 1463 } 1464 1465 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, 1466 Value v1) const { 1467 const auto &expr = exp(e); 1468 switch (expr.kind) { 1469 // Leaf. 1470 case TensorExp::Kind::kTensor: 1471 case TensorExp::Kind::kInvariant: 1472 case TensorExp::Kind::kLoopVar: 1473 case TensorExp::Kind::kSynZero: 1474 llvm_unreachable("unexpected non-op"); 1475 // Unary operations. 
1476 case TensorExp::Kind::kAbsF: 1477 return rewriter.create<math::AbsFOp>(loc, v0); 1478 case TensorExp::Kind::kAbsC: { 1479 auto type = cast<ComplexType>(v0.getType()); 1480 auto eltType = cast<FloatType>(type.getElementType()); 1481 return rewriter.create<complex::AbsOp>(loc, eltType, v0); 1482 } 1483 case TensorExp::Kind::kAbsI: 1484 return rewriter.create<math::AbsIOp>(loc, v0); 1485 case TensorExp::Kind::kCeilF: 1486 return rewriter.create<math::CeilOp>(loc, v0); 1487 case TensorExp::Kind::kFloorF: 1488 return rewriter.create<math::FloorOp>(loc, v0); 1489 case TensorExp::Kind::kSqrtF: 1490 return rewriter.create<math::SqrtOp>(loc, v0); 1491 case TensorExp::Kind::kSqrtC: 1492 return rewriter.create<complex::SqrtOp>(loc, v0); 1493 case TensorExp::Kind::kExpm1F: 1494 return rewriter.create<math::ExpM1Op>(loc, v0); 1495 case TensorExp::Kind::kExpm1C: 1496 return rewriter.create<complex::Expm1Op>(loc, v0); 1497 case TensorExp::Kind::kLog1pF: 1498 return rewriter.create<math::Log1pOp>(loc, v0); 1499 case TensorExp::Kind::kLog1pC: 1500 return rewriter.create<complex::Log1pOp>(loc, v0); 1501 case TensorExp::Kind::kSinF: 1502 return rewriter.create<math::SinOp>(loc, v0); 1503 case TensorExp::Kind::kSinC: 1504 return rewriter.create<complex::SinOp>(loc, v0); 1505 case TensorExp::Kind::kTanhF: 1506 return rewriter.create<math::TanhOp>(loc, v0); 1507 case TensorExp::Kind::kTanhC: 1508 return rewriter.create<complex::TanhOp>(loc, v0); 1509 case TensorExp::Kind::kNegF: 1510 return rewriter.create<arith::NegFOp>(loc, v0); 1511 case TensorExp::Kind::kNegC: 1512 return rewriter.create<complex::NegOp>(loc, v0); 1513 case TensorExp::Kind::kNegI: // no negi in std 1514 return rewriter.create<arith::SubIOp>( 1515 loc, 1516 rewriter.create<arith::ConstantOp>(loc, v0.getType(), 1517 rewriter.getZeroAttr(v0.getType())), 1518 v0); 1519 case TensorExp::Kind::kTruncF: 1520 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0); 1521 case TensorExp::Kind::kExtF: 1522 return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0); 1523 case TensorExp::Kind::kCastFS: 1524 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0); 1525 case TensorExp::Kind::kCastFU: 1526 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0); 1527 case TensorExp::Kind::kCastSF: 1528 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0); 1529 case TensorExp::Kind::kCastUF: 1530 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0); 1531 case TensorExp::Kind::kCastS: 1532 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0); 1533 case TensorExp::Kind::kCastU: 1534 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0); 1535 case TensorExp::Kind::kCastIdx: 1536 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); 1537 case TensorExp::Kind::kTruncI: 1538 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); 1539 case TensorExp::Kind::kCIm: { 1540 auto type = cast<ComplexType>(v0.getType()); 1541 auto eltType = cast<FloatType>(type.getElementType()); 1542 return rewriter.create<complex::ImOp>(loc, eltType, v0); 1543 } 1544 case TensorExp::Kind::kCRe: { 1545 auto type = cast<ComplexType>(v0.getType()); 1546 auto eltType = cast<FloatType>(type.getElementType()); 1547 return rewriter.create<complex::ReOp>(loc, eltType, v0); 1548 } 1549 case TensorExp::Kind::kBitCast: 1550 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); 1551 // Binary operations. 
1552 case TensorExp::Kind::kMulF: 1553 return rewriter.create<arith::MulFOp>(loc, v0, v1); 1554 case TensorExp::Kind::kMulC: 1555 return rewriter.create<complex::MulOp>(loc, v0, v1); 1556 case TensorExp::Kind::kMulI: 1557 return rewriter.create<arith::MulIOp>(loc, v0, v1); 1558 case TensorExp::Kind::kDivF: 1559 return rewriter.create<arith::DivFOp>(loc, v0, v1); 1560 case TensorExp::Kind::kDivC: 1561 return rewriter.create<complex::DivOp>(loc, v0, v1); 1562 case TensorExp::Kind::kDivS: 1563 return rewriter.create<arith::DivSIOp>(loc, v0, v1); 1564 case TensorExp::Kind::kDivU: 1565 return rewriter.create<arith::DivUIOp>(loc, v0, v1); 1566 case TensorExp::Kind::kAddF: 1567 return rewriter.create<arith::AddFOp>(loc, v0, v1); 1568 case TensorExp::Kind::kAddC: 1569 return rewriter.create<complex::AddOp>(loc, v0, v1); 1570 case TensorExp::Kind::kAddI: 1571 return rewriter.create<arith::AddIOp>(loc, v0, v1); 1572 case TensorExp::Kind::kSubF: 1573 return rewriter.create<arith::SubFOp>(loc, v0, v1); 1574 case TensorExp::Kind::kSubC: 1575 return rewriter.create<complex::SubOp>(loc, v0, v1); 1576 case TensorExp::Kind::kSubI: 1577 return rewriter.create<arith::SubIOp>(loc, v0, v1); 1578 case TensorExp::Kind::kAndI: 1579 return rewriter.create<arith::AndIOp>(loc, v0, v1); 1580 case TensorExp::Kind::kOrI: 1581 return rewriter.create<arith::OrIOp>(loc, v0, v1); 1582 case TensorExp::Kind::kXorI: 1583 return rewriter.create<arith::XOrIOp>(loc, v0, v1); 1584 case TensorExp::Kind::kShrS: 1585 return rewriter.create<arith::ShRSIOp>(loc, v0, v1); 1586 case TensorExp::Kind::kShrU: 1587 return rewriter.create<arith::ShRUIOp>(loc, v0, v1); 1588 case TensorExp::Kind::kShlI: 1589 return rewriter.create<arith::ShLIOp>(loc, v0, v1); 1590 case TensorExp::Kind::kCmpI: { 1591 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr); 1592 return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1); 1593 } 1594 case TensorExp::Kind::kCmpF: { 1595 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr); 1596 return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1); 1597 } 1598 case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. 1599 return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(), 1600 {v0}); 1601 case TensorExp::Kind::kUnary: 1602 return buildUnaryPresent(rewriter, loc, expr.op, v0); 1603 case TensorExp::Kind::kSelect: 1604 return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(), 1605 {v0}); 1606 case TensorExp::Kind::kBinary: 1607 return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1); 1608 case TensorExp::Kind::kReduce: { 1609 ReduceOp redOp = cast<ReduceOp>(expr.op); 1610 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1}); 1611 } 1612 } 1613 llvm_unreachable("unexpected expression kind in build"); 1614 } 1615 1616 } // namespace sparse_tensor 1617 } // namespace mlir 1618