#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

/// Assert that the given value range contains a single value and return it.
static Value getSingleValue(ValueRange values) {
  assert(values.size() == 1 && "expected single value");
  return values.front();
}

/// Appends the flattened storage fields of the given level to `fields`.
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                             SmallVectorImpl<Type> &fields) {
  // Position and coordinate buffer in the sparse structure.
  if (enc.getLvlType(lvl).isWithPosLT())
    fields.push_back(enc.getPosMemRefType());
  if (enc.getLvlType(lvl).isWithCrdLT())
    fields.push_back(enc.getCrdMemRefType());
  // One index for the shape bound (the result of lvlOp).
  fields.push_back(IndexType::get(enc.getContext()));
}

/// Flattens an iteration space type into the types of its component fields.
static std::optional<LogicalResult>
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
  auto idxTp = IndexType::get(itSp.getContext());
  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
    convertLevelType(itSp.getEncoding(), l, fields);

  // Two indices for the lower and upper bound (we only need one pair for the
  // last iteration space).
  fields.append({idxTp, idxTp});
  return success();
}

/// Flattens an iterator type into the types of its cursor values.
static std::optional<LogicalResult>
convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
  // The actual iterator values (updated on every iteration).
  auto idxTp = IndexType::get(itTp.getContext());
  // TODO: handle batch dimension.
  assert(itTp.getEncoding().getBatchLvlRank() == 0);
  if (!itTp.isUnique()) {
    // Segment high for a non-unique iterator.
    fields.push_back(idxTp);
  }
  fields.push_back(idxTp);
  return success();
}
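
// An illustrative sketch of the flattened layouts these conversions produce
// (not normative; the exact field types depend on the encoding): an iteration
// space over a single compressed level, e.g. the second level of a CSR
// tensor, flattens to
//   { memref<?xindex> positions, memref<?xindex> coordinates,
//     index levelSize, index lo, index hi }
// while a unique iterator over such a space flattens to a single `index`
// cursor, and a non-unique one carries an extra `index` for its segment high.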

/// Generates a nest of scf.if branches for the given co-iteration cases, one
/// branch per case, and returns the values carried out of the nest. Each case
/// is guarded by the predicate that all its iterators are positioned at the
/// current loop coordinate.
static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
                       Value loopCrd,
                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
                       ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
                       ArrayRef<Value> userReduc) {
  if (newBlocks.empty())
    return userReduc;

  // The current branch that we are handling.
  Block *newBlock = newBlocks.front();
  Block *oldBlock = oldBlocks.front();
  Value casePred = constantI1(rewriter, loc, true);
  I64BitSet caseBits =
      op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
  for (unsigned i : caseBits.bits()) {
    SparseIterator *it = iters[i].get();
    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                it->getCrd(), loopCrd);
    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

  // Erase the empty block.
  rewriter.eraseBlock(&ifOp.getThenRegion().front());
  // Set up block arguments: user-provided values -> loop coord -> iterators.
  SmallVector<Value> blockArgs(userReduc);
  blockArgs.push_back(loopCrd);
  for (unsigned idx : caseBits.bits())
    llvm::append_range(blockArgs, iters[idx]->getCursor());

  // Map the old block arguments, because the dialect conversion driver does
  // not immediately perform SSA value replacements; this function still sees
  // the old uses.
  IRMapping mapping;
  for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
    mapping.map(from, to);
  }

  // Clone the region; we cannot erase it yet because the same region might be
  // a subcase of multiple lattice points.
  rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
                             ifOp.getThenRegion().begin(), mapping);
  // Remove the block arguments; they were already replaced via `mapping`.
  ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());

  // Replace sparse_tensor::YieldOp with scf::YieldOp.
  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
  ValueRange yields = spY.getResults();
  rewriter.eraseOp(spY);
  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
  rewriter.create<scf::YieldOp>(loc, yields);

  // Generate the remaining cases recursively.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
                                          newBlocks.drop_front(),
                                          oldBlocks.drop_front(), userReduc);
  if (!res.empty())
    rewriter.create<scf::YieldOp>(loc, res);

  rewriter.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}
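
// For example, given cases {A,B} and {A} (a hedged sketch; all names are
// illustrative), genCoIterateBranchNest emits roughly:
//
//   %pA  = arith.cmpi eq, %crdA, %loopCrd
//   %pB  = arith.cmpi eq, %crdB, %loopCrd
//   %pAB = arith.andi %pA, %pB
//   scf.if %pAB {
//     <cloned case {A,B}>
//   } else {
//     %pA2 = arith.cmpi eq, %crdA, %loopCrd
//     scf.if %pA2 { <cloned case {A}> } else { scf.yield %userReduc }
//   }
//
// (The constant-true seed of each predicate chain is folded away here for
// brevity.)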

/// Generates a loop (an scf.for when the iterator supports it, an scf.while
/// otherwise) that drives the given iterator, delegating body generation to
/// `bodyBuilder` and returning the final reduction values.
static ValueRange genLoopWithIterator(
    PatternRewriter &rewriter, Location loc, SparseIterator *it,
    ValueRange reduc,
    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
                                    Region &loopBody, SparseIterator *it,
                                    ValueRange reduc)>
        bodyBuilder) {
  if (it->iteratableByFor()) {
    auto [lo, hi] = it->genForCond(rewriter, loc);
    Value step = constantIndex(rewriter, loc, 1);
    scf::ForOp forOp = rewriter.create<scf::ForOp>(
        loc, lo, hi, step, reduc,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
          // Empty builder function to ensure that no terminator is created.
        });
    {
      OpBuilder::InsertionGuard guard(rewriter);
      it->linkNewScope(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
                                           it, forOp.getRegionIterArgs());

      rewriter.setInsertionPointToEnd(forOp.getBody());
      rewriter.create<scf::YieldOp>(loc, ret);
    }
    return forOp.getResults();
  }

  SmallVector<Value> ivs(reduc);
  llvm::append_range(ivs, it->getCursor());

  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
  {
    OpBuilder::InsertionGuard guard(rewriter);
    // Generates loop conditions.
    SmallVector<Location> l(types.size(), loc);
    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
    rewriter.setInsertionPointToStart(before);
    ValueRange bArgs = before->getArguments();
    auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

    // Delegates loop body generation.
    Region &dstRegion = whileOp.getAfter();
    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
    ValueRange aArgs = whileOp.getAfterArguments();
    it->linkNewScope(aArgs.drop_front(reduc.size()));
    aArgs = aArgs.take_front(reduc.size());

    rewriter.setInsertionPointToStart(after);
    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
    rewriter.setInsertionPointToEnd(after);

    // Forward the iterator and yield the new loop-carried values.
    SmallVector<Value> yields;
    llvm::append_range(yields, ret);
    llvm::append_range(yields, it->forward(rewriter, loc));
    rewriter.create<scf::YieldOp>(loc, yields);
  }
  return whileOp.getResults().drop_front(it->getCursor().size());
}
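
// The two loop shapes produced above, sketched for a single reduction value
// (illustrative only; names are placeholders):
//
//   scf.for %iv = %lo to %hi step %c1 iter_args(%red = %init) {
//     <body>
//     scf.yield %red2
//   }
//
// versus, for iterators that cannot be driven by a simple for-loop:
//
//   scf.while (%red = %init, %cur = <cursor>) {
//     %cond = <genWhileCond>
//     scf.condition(%cond) %red, %cur
//   } do {
//   ^bb0(%red: index, %cur: index):
//     <body>
//     scf.yield %red2, <forwarded cursor>
//   }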

namespace {

/// Sparse codegen rule for the iteration space extraction operator.
class ExtractIterSpaceConverter
    : public OpConversionPattern<ExtractIterSpaceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Construct the iteration space.
    SparseIterationSpace space(loc, rewriter,
                               getSingleValue(adaptor.getTensor()), 0,
                               op.getLvlRange(), adaptor.getParentIter());

    SmallVector<Value> result = space.toValues();
    rewriter.replaceOpWithMultiple(op, {result});
    return success();
  }
};

/// Sparse codegen rule for the value extraction operator.
class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value pos = adaptor.getIterator().back();
    Value valBuf =
        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
    return success();
  }
};

/// Sparse codegen rule that lowers sparse_tensor.iterate to an scf loop.
class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getCrdUsedLvls().empty())
      return rewriter.notifyMatchFailure(
          op, "non-empty coordinates list not implemented.");

    Location loc = op.getLoc();

    auto iterSpace = SparseIterationSpace::fromValues(
        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);

    std::unique_ptr<SparseIterator> it =
        iterSpace.extractIterator(rewriter, loc);

    SmallVector<Value> ivs;
    for (ValueRange inits : adaptor.getInitArgs())
      llvm::append_range(ivs, inits);

    // Type conversion on the iterate op block.
    unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
    TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
    if (failed(typeConverter->convertSignatureArgs(
            op.getBody()->getArgumentTypes(), signatureConversion)))
      return rewriter.notifyMatchFailure(
          op, "failed to convert iterate region argument types");

    Block *block = rewriter.applySignatureConversion(
        op.getBody(), signatureConversion, getTypeConverter());
    ValueRange ret = genLoopWithIterator(
        rewriter, loc, it.get(), ivs,
        [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
                SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
          SmallVector<Value> blockArgs(reduc);
          // TODO: Also append coordinates if used.
          // blockArgs.push_back(it->deref(rewriter, loc));
          llvm::append_range(blockArgs, it->getCursor());

          Block *dstBlock = &loopBody.getBlocks().front();
          rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
                                     blockArgs);
          auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
          // We cannot use a ValueRange since the operation holding the values
          // will be destroyed.
          SmallVector<Value> result(yield.getResults());
          rewriter.eraseOp(yield);
          return result;
        });

    rewriter.replaceOp(op, ret);
    return success();
  }
};
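
// An illustrative end-to-end sketch of the rewrite above (syntax and types
// heavily abbreviated; not the exact printed form):
//
//   %sum = sparse_tensor.iterate %it in %space iter_args(%a = %init)
//       : !sparse_tensor.iter_space<...> -> index { ... }
//
// becomes, when the iterator is not iterable by a simple for-loop, an
// scf.while that carries the flattened init values together with the iterator
// cursor, with the original region inlined into the loop body and the
// sparse_tensor.yield rewritten into the loop yield.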

/// Sparse codegen rule that lowers sparse_tensor.coiterate to a sequence of
/// scf loops, one per case region.
class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(op.getSpaceDim() == 1 && "Not implemented");
    Location loc = op.getLoc();

    I64BitSet denseBits(0);
    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
        denseBits.set(idx);

    // If there exists a case that contains only dense spaces (i.e., the case
    // bits are a subset of the dense bits), or a fully empty case (due to
    // complements), we need a universal pointer to forward the coiteration
    // loop.
    bool needUniv =
        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
          // A case for complement.
          if (caseBits.count() == 0)
            return true;
          // An all-dense case.
          return caseBits.isSubSetOf(denseBits);
        });
    assert(!needUniv && "Not implemented");
    (void)needUniv;

    SmallVector<Block *> newBlocks;
    DenseMap<Block *, Block *> newToOldBlockMap;
    for (Region &region : op.getCaseRegions()) {
      // Do a one-shot type conversion on all region blocks, since the same
      // region might be used multiple times.
      Block *block = &region.getBlocks().front();
      TypeConverter::SignatureConversion blockTypeMapping(
          block->getArgumentTypes().size());
      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
                                                     blockTypeMapping))) {
        return rewriter.notifyMatchFailure(
            op, "failed to convert coiterate region argument types");
      }

      newBlocks.push_back(rewriter.applySignatureConversion(
          block, blockTypeMapping, getTypeConverter()));
      newToOldBlockMap[newBlocks.back()] = block;
    }

    SmallVector<SparseIterationSpace> spaces;
    SmallVector<std::unique_ptr<SparseIterator>> iters;
    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
      // TODO: do we really need tid?
      spaces.push_back(SparseIterationSpace::fromValues(
          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
      // Extract the iterator.
      iters.push_back(spaces.back().extractIterator(rewriter, loc));
    }

    auto getFilteredIters = [&iters](I64BitSet caseBits) {
      // Retrieves a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters;
      for (auto idx : caseBits.bits())
        validIters.push_back(iters[idx].get());
      return validIters;
    };

    // Get the flattened user-provided loop reduction values.
    SmallVector<Value> userReduc;
    for (ValueRange r : adaptor.getInitArgs())
      llvm::append_range(userReduc, r);

    // TODO: we need to sort the cases such that they appear in lexical order.
    // Although sparsification always generates cases in that order, it might
    // not be the case for human-written code.

    // Generates a loop sequence, one loop per case.
    for (auto [r, caseBits] :
         llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
      assert(caseBits.count() > 0 && "Complement space not implemented");

      // Retrieves a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);

      if (validIters.size() > 1) {
        auto [loop, loopCrd] =
            genCoIteration(rewriter, loc, validIters, userReduc,
                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);

        // First, find all cases whose conditions are subsets of the current
        // case condition, for which we generate one branch per case inside
        // the loop. The subcases are never empty; they contain at least the
        // current region itself.
        // TODO: these cases should be sorted.
        SmallVector<Region *> subCases =
            op.getSubCasesOf(r->getParent()->getRegionNumber());
        SmallVector<Block *> newBlocks, oldBlocks;
        for (Region *r : subCases) {
          newBlocks.push_back(&r->front());
          oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
        }
        assert(!subCases.empty());

        ValueRange res = genCoIterateBranchNest(
            rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);

        SmallVector<Value> nextIterYields(res);
        // Second, forward the loop.
        for (SparseIterator *it : validIters) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
          it->forwardIf(rewriter, loc, cmp);
          llvm::append_range(nextIterYields, it->getCursor());
        }
        rewriter.create<scf::YieldOp>(loc, nextIterYields);

        // Exit the loop and relink the iterator SSA values.
        rewriter.setInsertionPointAfter(loop);
        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
        for (SparseIterator *it : validIters)
          iterVals = it->linkNewScope(iterVals);
        assert(iterVals.empty());

        ValueRange curResult = loop->getResults().take_front(userReduc.size());
        userReduc.assign(curResult.begin(), curResult.end());
      } else {
        // This is a simple iteration loop.
        assert(caseBits.count() == 1);

        Block *block = r;
        ValueRange curResult = genLoopWithIterator(
            rewriter, loc, validIters.front(), userReduc,
            /*bodyBuilder=*/
            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
                    SparseIterator *it,
                    ValueRange reduc) -> SmallVector<Value> {
              SmallVector<Value> blockArgs(reduc);
              blockArgs.push_back(it->deref(rewriter, loc));
              llvm::append_range(blockArgs, it->getCursor());

              Block *dstBlock = &dstRegion.getBlocks().front();
              rewriter.inlineBlockBefore(
                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
              SmallVector<Value> result(yield.getResults());
              rewriter.eraseOp(yield);
              return result;
            });

        userReduc.assign(curResult.begin(), curResult.end());
      }
    }

    rewriter.replaceOp(op, userReduc);
    return success();
  }
};
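
// To summarize the rewrite above: each case of the coiterate op becomes its
// own loop, executed in sequence. Multi-iterator cases use the while-loop
// from genCoIteration, dispatch to their subcases through the scf.if nest
// built by genCoIterateBranchNest, and then conditionally forward every
// iterator whose coordinate matches the loop coordinate; single-iterator
// cases reduce to a plain genLoopWithIterator loop. The user reduction values
// are threaded from one loop into the next.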

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertIteratorType);
  addConversion(convertIterSpaceType);

  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
                              ValueRange inputs, Location loc) -> Value {
    return builder
        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
        .getResult(0);
  });
}

void mlir::populateLowerSparseIterationToSCFPatterns(
    const TypeConverter &converter, RewritePatternSet &patterns) {
  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
               SparseIterateOpConverter, SparseCoIterateOpConverter>(
      converter, patterns.getContext());
}
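
// A minimal sketch of how a pass might drive these patterns (hypothetical
// driver code, not part of this file; it assumes the usual dialect conversion
// setup):
//
//   void runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     SparseIterationTypeConverter converter;
//     ConversionTarget target(*ctx);
//     target.addIllegalOp<ExtractIterSpaceOp, ExtractValOp, IterateOp,
//                         CoIterateOp>();
//     RewritePatternSet patterns(ctx);
//     populateLowerSparseIterationToSCFPatterns(converter, patterns);
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }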