//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LoopEmitter.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File local shorthand macros
//===----------------------------------------------------------------------===//

#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))

//===----------------------------------------------------------------------===//
// Debugging utils
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif

//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//

// For index reduction loops, since the tensors are sliced into non-contiguous
// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
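//
// For example, a slice that covers two non-contiguous fragments of a level is
// described by two such triples: the pairs (pLo0, pHi0) and (pLo1, pHi1) bound
// the two fragments in the position range, while pPtr0 and pPtr1 locate the
// corresponding fragments in the child level.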
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
    return constantIndex(builder, loc, *i);
  return cast<Value>(ofr);
}

static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter,
                         SparseEmitStrategy emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // The tensors array (len == numManifestTensors).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  // Reserve space for the loop and loop-sequence stacks; entries are pushed
  // as loops are generated.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
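  // All three fields are indexed first by TensorId and then by Level:
  // dependentLvlMap[t][l] records the (loop, coefficient) pairs that level l
  // of tensor t depends on, levelReducedDep[t][l] tracks how many of those
  // dependencies have been reduced by the enclosing loop sequences, and
  // sliceMeta[t][l] reserves one slot of slice metadata per dependence.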
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // The synthetic tensor is (conceptually) an all-dense tensor with rank
      // equal to the total number of loops (each level can potentially be
      // mapped to one of the loops being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or zero-ranked tensor.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops in loop order.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}

void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {

  // For every manifest tensor, set up the values buffer.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skip only scalars; zero-ranked tensors still need to be bufferized and
    // (probably) filled with zeros by users.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegate extra output initialization to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use a fully dynamic layout here, it breaks
      // some vectorization passes which require a static stride of 1.
      // Is it possible to run the vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For the synthetic tensor, set the high bound of each loop by calling the
  // callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    // TODO: this should be done through a folding pass after switching to
    // `sparse_tensor.iterate`-based sparsification.
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skip only scalars; zero-ranked tensors still need to be bufferized and
      // (probably) filled with zeros by users.
      continue;

    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of the current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find the upper bound in the current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
    // NOTE: we could also prepare for level 0 here in advance; this would
    // hoist some loop preparation out of tensor iteration, but would also
    // (undesirably) hoist the code outside of if-conditions.
  }
  // TODO: avoid treating the subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}

void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse the queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    std::sort(depRedOrder.begin(), depRedOrder.end(),
              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor level that we should use to generate loops. Among
  // all the tensor levels, there is at most one sparse tensor level.
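  // Random-access (dense) iterators can be positioned directly via locate and
  // thus need no co-iteration; the sparse iterators gathered below are sorted
  // by kind and drive the generated loop conditions whenever present.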
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
    // AffineUnRed > Affine > Slice > Trivial
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());

  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepare for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // The universal index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Depending on whether the slice is resolved or not at the current loop
  // sequence, end them in different ways.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been led to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals are not the actual reduction variables but are
    // instead used as a "special handle" to (temporarily) represent them. The
    // expression on init vals will be moved into scf.reduce and replaced with
    // the block arguments when exiting the loop (see exitForLoop). This is
    // needed as we cannot build the actual reduction block and get the actual
    // reduction variable before users fill the parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // If we need to co-iterate over more than one sparse tensor, we need a
  // while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // The block arguments consist of the used coordinates (of index type),
  // followed by the user-provided iteration arguments, and end with the
  // iterators that define the actual iteration space.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());
  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> locs(blockArgTps.size(), loc);
  caseRegion.emplaceBlock().addArguments(blockArgTps, locs);

  // Enter the new region scope, updating the SSA chain.
  builder.setInsertionPointToStart(&caseRegion.front());
  // Update the coordinates.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // Update loop iteration arguments.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());
  // Update sparse iterator values.
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
  for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }
  // Must have consumed all iterator SSA values.
  assert(iters.empty());
  return &caseRegion;
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  // TODO: Argument `numCases` is only used when generating iterator-based
  // sparse loops. Simplify the code once the feature is complete.
  // TODO: handle coiteration with sparse iterator.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
      Value t = tensors[tid];

      // Extract and iterate over the iteration space.
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // Update the reduction variables.
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
      // Set the insertion point to the loop body.
      builder.setInsertionPointToStart(iterOp.getBody());
      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                             iterOp.getCrds().front(), loopTag);
      return iterOp;
    }

    // Co-iteration loops.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    // The CoIterateOp has neither an insertion block nor an induction
    // variable.
    // TODO: the `struct LoopInfo` should be simplified after full migration.
    loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
                           /*induction variable*/ nullptr, loopTag);
    return coIterOp;
  }

  // TODO: support multiple return on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  // TODO: Maybe we should instead require the merger to pass in a valid value
  // in the first place instead of adjusting it in LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick an arbitrary one (a dense
  // slice-driven loop can be generated using a simple ForOp as well).
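  // When at most one sparse iterator remains (and it can be iterated by a for
  // loop) and no universal index is needed, a single scf.for/scf.parallel
  // suffices; otherwise the co-iteration is emitted as an scf.while by
  // emitWhileLoopOverTensorsAtLvls below.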
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generate loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // NOTE: we could also prepare for the next dim here in advance.
  // Push the loop onto the stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}

void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent
  // iterator is memorized by the iterator itself.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locate the random-accessible iterator at position 0.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update reduction variables.
    llvm::copy(iterateOp.getResults(), reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no uses.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: It is the user's responsibility to ensure that the operation is
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replace the arguments of the reduction expression with the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erase the outdated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the
      // right position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction values from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  // Clean up the values; this helps us discover potential bugs at an earlier
  // stage (instead of silently using a wrong value).
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    Operation *p = loopInfo.loop;
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop.
    rewriter.setInsertionPointAfter(p);
    // In-place update reduction variables.
    llvm::copy(p->getResults(), reduc.begin());
    loopStack.pop_back();
    return;
  }

  // Set the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args. In this case, we need to insert the code before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

//===----------------------------------------------------------------------===//
// Loop generation utils
//===----------------------------------------------------------------------===//

std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // NOTE: the slice-driven tensor-related reduction variables must
  // appear before normal tensors.

  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // TODO: remove the flag after full migration. Currently the
  // `sparse_tensor.coiterate` operation (must) puts user-provided reduction
  // values at the front of the block argument list, while direct
  // sparsification to scf loops puts them at the end.
  if (userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Construct the while-loop with a parameter for each coordinate.
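  // Each sparse iterator contributes its full cursor (one or more SSA values)
  // as loop-carried arguments of the scf.while operation.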
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  if (!userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Update universal index.
  if (uniIdx)
    ivs.push_back(uniIdx);

  // Ensure all operands are valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generate the loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generate the loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();
  // Since some LoopCondKind might need extra checks to filter out invalid
  // iterations, we maintain another array to hold the iteration arguments to
  // yield if the checks fail.
  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on reduction variables.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Find the minimum coordinate.
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimal position.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT