1c42bbda4SPeiming Liu 2c42bbda4SPeiming Liu #include "Utils/CodegenUtils.h" 3f607102aSPeiming Liu #include "Utils/LoopEmitter.h" 4c42bbda4SPeiming Liu #include "Utils/SparseTensorIterator.h" 5c42bbda4SPeiming Liu 6951a3630SPeiming Liu #include "mlir/Dialect/MemRef/IR/MemRef.h" 7c42bbda4SPeiming Liu #include "mlir/Dialect/SCF/IR/SCF.h" 8c42bbda4SPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 9c42bbda4SPeiming Liu #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 10*2b5b3cf6SMatthias Springer #include "mlir/Transforms/DialectConversion.h" 11c42bbda4SPeiming Liu 12c42bbda4SPeiming Liu using namespace mlir; 13c42bbda4SPeiming Liu using namespace mlir::sparse_tensor; 14c42bbda4SPeiming Liu 15*2b5b3cf6SMatthias Springer /// Assert that the given value range contains a single value and return it. 16*2b5b3cf6SMatthias Springer static Value getSingleValue(ValueRange values) { 17*2b5b3cf6SMatthias Springer assert(values.size() == 1 && "expected single value"); 18*2b5b3cf6SMatthias Springer return values.front(); 19*2b5b3cf6SMatthias Springer } 20*2b5b3cf6SMatthias Springer 21951a3630SPeiming Liu static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, 22c42bbda4SPeiming Liu SmallVectorImpl<Type> &fields) { 23c42bbda4SPeiming Liu // Position and coordinate buffer in the sparse structure. 24c42bbda4SPeiming Liu if (enc.getLvlType(lvl).isWithPosLT()) 25c42bbda4SPeiming Liu fields.push_back(enc.getPosMemRefType()); 26c42bbda4SPeiming Liu if (enc.getLvlType(lvl).isWithCrdLT()) 27c42bbda4SPeiming Liu fields.push_back(enc.getCrdMemRefType()); 28c42bbda4SPeiming Liu // One index for shape bound (result from lvlOp). 29c42bbda4SPeiming Liu fields.push_back(IndexType::get(enc.getContext())); 30c42bbda4SPeiming Liu } 31c42bbda4SPeiming Liu 32c42bbda4SPeiming Liu static std::optional<LogicalResult> 33c42bbda4SPeiming Liu convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) { 34c42bbda4SPeiming Liu 35c42bbda4SPeiming Liu auto idxTp = IndexType::get(itSp.getContext()); 36c42bbda4SPeiming Liu for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++) 37c42bbda4SPeiming Liu convertLevelType(itSp.getEncoding(), l, fields); 38c42bbda4SPeiming Liu 39c42bbda4SPeiming Liu // Two indices for lower and upper bound (we only need one pair for the last 40c42bbda4SPeiming Liu // iteration space). 41c42bbda4SPeiming Liu fields.append({idxTp, idxTp}); 42c42bbda4SPeiming Liu return success(); 43c42bbda4SPeiming Liu } 44c42bbda4SPeiming Liu 45d6cc35f7SPeiming Liu static std::optional<LogicalResult> 46d6cc35f7SPeiming Liu convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) { 47d6cc35f7SPeiming Liu // The actually Iterator Values (that are updated every iteration). 48d6cc35f7SPeiming Liu auto idxTp = IndexType::get(itTp.getContext()); 49d6cc35f7SPeiming Liu // TODO: handle batch dimension. 50d6cc35f7SPeiming Liu assert(itTp.getEncoding().getBatchLvlRank() == 0); 51d6cc35f7SPeiming Liu if (!itTp.isUnique()) { 52d6cc35f7SPeiming Liu // Segment high for non-unique iterator. 53d6cc35f7SPeiming Liu fields.push_back(idxTp); 54d6cc35f7SPeiming Liu } 55d6cc35f7SPeiming Liu fields.push_back(idxTp); 56d6cc35f7SPeiming Liu return success(); 57d6cc35f7SPeiming Liu } 58d6cc35f7SPeiming Liu 59f607102aSPeiming Liu static ValueRange 60f607102aSPeiming Liu genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, 61f607102aSPeiming Liu Value loopCrd, 62f607102aSPeiming Liu ArrayRef<std::unique_ptr<SparseIterator>> iters, 63*2b5b3cf6SMatthias Springer ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks, 64*2b5b3cf6SMatthias Springer ArrayRef<Value> userReduc) { 65*2b5b3cf6SMatthias Springer if (newBlocks.empty()) 66f607102aSPeiming Liu return userReduc; 67f607102aSPeiming Liu 68f607102aSPeiming Liu // The current branch that we are handling. 69*2b5b3cf6SMatthias Springer Block *newBlock = newBlocks.front(); 70*2b5b3cf6SMatthias Springer Block *oldBlock = oldBlocks.front(); 71f607102aSPeiming Liu Value casePred = constantI1(rewriter, loc, true); 72*2b5b3cf6SMatthias Springer I64BitSet caseBits = 73*2b5b3cf6SMatthias Springer op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber()); 74f607102aSPeiming Liu for (unsigned i : caseBits.bits()) { 75f607102aSPeiming Liu SparseIterator *it = iters[i].get(); 76f607102aSPeiming Liu Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 77f607102aSPeiming Liu it->getCrd(), loopCrd); 78f607102aSPeiming Liu casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred); 79f607102aSPeiming Liu } 80f607102aSPeiming Liu scf::IfOp ifOp = rewriter.create<scf::IfOp>( 81f607102aSPeiming Liu loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true); 82f607102aSPeiming Liu rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 83f607102aSPeiming Liu 84f607102aSPeiming Liu // Erase the empty block. 85f607102aSPeiming Liu rewriter.eraseBlock(&ifOp.getThenRegion().front()); 86f607102aSPeiming Liu // Set up block arguments: user-provided values -> loop coord -> iterators. 87f607102aSPeiming Liu SmallVector<Value> blockArgs(userReduc); 88f607102aSPeiming Liu blockArgs.push_back(loopCrd); 89f607102aSPeiming Liu for (unsigned idx : caseBits.bits()) 90f607102aSPeiming Liu llvm::append_range(blockArgs, iters[idx]->getCursor()); 91f607102aSPeiming Liu 92*2b5b3cf6SMatthias Springer // Map the old block arguments, because the dialect conversion driver does 93*2b5b3cf6SMatthias Springer // not immediately perform SSA value replacements. This function is still 94*2b5b3cf6SMatthias Springer // seeing the old uses. 95f607102aSPeiming Liu IRMapping mapping; 96*2b5b3cf6SMatthias Springer for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) { 97f607102aSPeiming Liu mapping.map(from, to); 98f607102aSPeiming Liu } 99f607102aSPeiming Liu 100f607102aSPeiming Liu // Clone the region, we can not erase the region now because the same region 101f607102aSPeiming Liu // might be a subcase for multiple lattice point. 102*2b5b3cf6SMatthias Springer rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(), 103f607102aSPeiming Liu ifOp.getThenRegion().begin(), mapping); 104*2b5b3cf6SMatthias Springer // Remove the block arguments, they were already replaced via `mapping`. 105*2b5b3cf6SMatthias Springer ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size()); 106f607102aSPeiming Liu 107f607102aSPeiming Liu // replace sparse_tensor::YieldOp -> scf::YieldOp 108f607102aSPeiming Liu auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back()); 109f607102aSPeiming Liu ValueRange yields = spY.getResults(); 110f607102aSPeiming Liu rewriter.eraseOp(spY); 111f607102aSPeiming Liu rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front()); 112f607102aSPeiming Liu rewriter.create<scf::YieldOp>(loc, yields); 113f607102aSPeiming Liu 114f607102aSPeiming Liu // Generates remaining case recursively. 115f607102aSPeiming Liu rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 116f607102aSPeiming Liu ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters, 117*2b5b3cf6SMatthias Springer newBlocks.drop_front(), 118*2b5b3cf6SMatthias Springer oldBlocks.drop_front(), userReduc); 119f607102aSPeiming Liu if (!res.empty()) 120f607102aSPeiming Liu rewriter.create<scf::YieldOp>(loc, res); 121f607102aSPeiming Liu 122f607102aSPeiming Liu rewriter.setInsertionPointAfter(ifOp); 123f607102aSPeiming Liu return ifOp.getResults(); 124f607102aSPeiming Liu } 125f607102aSPeiming Liu 126f607102aSPeiming Liu static ValueRange genLoopWithIterator( 127f607102aSPeiming Liu PatternRewriter &rewriter, Location loc, SparseIterator *it, 128b48ef8d8SPeiming Liu ValueRange reduc, 129f607102aSPeiming Liu function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc, 130f607102aSPeiming Liu Region &loopBody, SparseIterator *it, 131f607102aSPeiming Liu ValueRange reduc)> 132f607102aSPeiming Liu bodyBuilder) { 133f607102aSPeiming Liu if (it->iteratableByFor()) { 134f607102aSPeiming Liu auto [lo, hi] = it->genForCond(rewriter, loc); 135f607102aSPeiming Liu Value step = constantIndex(rewriter, loc, 1); 136*2b5b3cf6SMatthias Springer scf::ForOp forOp = rewriter.create<scf::ForOp>( 137*2b5b3cf6SMatthias Springer loc, lo, hi, step, reduc, 138*2b5b3cf6SMatthias Springer [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 139*2b5b3cf6SMatthias Springer // Empty builder function to ensure that no terminator is created. 140*2b5b3cf6SMatthias Springer }); 141f607102aSPeiming Liu { 142f607102aSPeiming Liu OpBuilder::InsertionGuard guard(rewriter); 143f607102aSPeiming Liu it->linkNewScope(forOp.getInductionVar()); 144f607102aSPeiming Liu rewriter.setInsertionPointToStart(forOp.getBody()); 145f607102aSPeiming Liu SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(), 146f607102aSPeiming Liu it, forOp.getRegionIterArgs()); 147f607102aSPeiming Liu 148f607102aSPeiming Liu rewriter.setInsertionPointToEnd(forOp.getBody()); 149f607102aSPeiming Liu rewriter.create<scf::YieldOp>(loc, ret); 150f607102aSPeiming Liu } 151f607102aSPeiming Liu return forOp.getResults(); 152f607102aSPeiming Liu } 153b48ef8d8SPeiming Liu 154b48ef8d8SPeiming Liu SmallVector<Value> ivs(reduc); 155f607102aSPeiming Liu llvm::append_range(ivs, it->getCursor()); 156f607102aSPeiming Liu 157f607102aSPeiming Liu TypeRange types = ValueRange(ivs).getTypes(); 158f607102aSPeiming Liu auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); 159f607102aSPeiming Liu { 160f607102aSPeiming Liu OpBuilder::InsertionGuard guard(rewriter); 161f607102aSPeiming Liu // Generates loop conditions. 162f607102aSPeiming Liu SmallVector<Location> l(types.size(), loc); 163f607102aSPeiming Liu Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); 164f607102aSPeiming Liu rewriter.setInsertionPointToStart(before); 165f607102aSPeiming Liu ValueRange bArgs = before->getArguments(); 166f607102aSPeiming Liu auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); 167f607102aSPeiming Liu rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments()); 168f607102aSPeiming Liu 169f607102aSPeiming Liu // Delegates loop body generation. 170f607102aSPeiming Liu Region &dstRegion = whileOp.getAfter(); 171f607102aSPeiming Liu Block *after = rewriter.createBlock(&dstRegion, {}, types, l); 172f607102aSPeiming Liu ValueRange aArgs = whileOp.getAfterArguments(); 173f607102aSPeiming Liu it->linkNewScope(aArgs.drop_front(reduc.size())); 174b48ef8d8SPeiming Liu aArgs = aArgs.take_front(reduc.size()); 175f607102aSPeiming Liu 176f607102aSPeiming Liu rewriter.setInsertionPointToStart(after); 177f607102aSPeiming Liu SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs); 178f607102aSPeiming Liu rewriter.setInsertionPointToEnd(after); 179f607102aSPeiming Liu 180f607102aSPeiming Liu // Forward loops 181f607102aSPeiming Liu SmallVector<Value> yields; 182f607102aSPeiming Liu llvm::append_range(yields, ret); 183b48ef8d8SPeiming Liu llvm::append_range(yields, it->forward(rewriter, loc)); 184f607102aSPeiming Liu rewriter.create<scf::YieldOp>(loc, yields); 185f607102aSPeiming Liu } 186f607102aSPeiming Liu return whileOp.getResults().drop_front(it->getCursor().size()); 187f607102aSPeiming Liu } 188f607102aSPeiming Liu 189c42bbda4SPeiming Liu namespace { 190c42bbda4SPeiming Liu 191c42bbda4SPeiming Liu /// Sparse codegen rule for number of entries operator. 192c42bbda4SPeiming Liu class ExtractIterSpaceConverter 193*2b5b3cf6SMatthias Springer : public OpConversionPattern<ExtractIterSpaceOp> { 194c42bbda4SPeiming Liu public: 195*2b5b3cf6SMatthias Springer using OpConversionPattern::OpConversionPattern; 196c42bbda4SPeiming Liu LogicalResult 197*2b5b3cf6SMatthias Springer matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor, 198*2b5b3cf6SMatthias Springer ConversionPatternRewriter &rewriter) const override { 199c42bbda4SPeiming Liu Location loc = op.getLoc(); 200c42bbda4SPeiming Liu 201c42bbda4SPeiming Liu // Construct the iteration space. 202*2b5b3cf6SMatthias Springer SparseIterationSpace space(loc, rewriter, 203*2b5b3cf6SMatthias Springer getSingleValue(adaptor.getTensor()), 0, 204c42bbda4SPeiming Liu op.getLvlRange(), adaptor.getParentIter()); 205c42bbda4SPeiming Liu 206c42bbda4SPeiming Liu SmallVector<Value> result = space.toValues(); 207*2b5b3cf6SMatthias Springer rewriter.replaceOpWithMultiple(op, {result}); 208c42bbda4SPeiming Liu return success(); 209c42bbda4SPeiming Liu } 210c42bbda4SPeiming Liu }; 211c42bbda4SPeiming Liu 212951a3630SPeiming Liu /// Sparse codegen rule for number of entries operator. 213*2b5b3cf6SMatthias Springer class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> { 214951a3630SPeiming Liu public: 215*2b5b3cf6SMatthias Springer using OpConversionPattern::OpConversionPattern; 216951a3630SPeiming Liu LogicalResult 217*2b5b3cf6SMatthias Springer matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor, 218*2b5b3cf6SMatthias Springer ConversionPatternRewriter &rewriter) const override { 219951a3630SPeiming Liu Location loc = op.getLoc(); 220951a3630SPeiming Liu Value pos = adaptor.getIterator().back(); 221*2b5b3cf6SMatthias Springer Value valBuf = 222*2b5b3cf6SMatthias Springer rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor())); 223951a3630SPeiming Liu rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos); 224951a3630SPeiming Liu return success(); 225951a3630SPeiming Liu } 226951a3630SPeiming Liu }; 227951a3630SPeiming Liu 228*2b5b3cf6SMatthias Springer class SparseIterateOpConverter : public OpConversionPattern<IterateOp> { 229d6cc35f7SPeiming Liu public: 230*2b5b3cf6SMatthias Springer using OpConversionPattern::OpConversionPattern; 231d6cc35f7SPeiming Liu LogicalResult 232*2b5b3cf6SMatthias Springer matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor, 233*2b5b3cf6SMatthias Springer ConversionPatternRewriter &rewriter) const override { 234d6cc35f7SPeiming Liu if (!op.getCrdUsedLvls().empty()) 235d6cc35f7SPeiming Liu return rewriter.notifyMatchFailure( 236d6cc35f7SPeiming Liu op, "non-empty coordinates list not implemented."); 237d6cc35f7SPeiming Liu 238d6cc35f7SPeiming Liu Location loc = op.getLoc(); 239d6cc35f7SPeiming Liu 240d6cc35f7SPeiming Liu auto iterSpace = SparseIterationSpace::fromValues( 241d6cc35f7SPeiming Liu op.getIterSpace().getType(), adaptor.getIterSpace(), 0); 242d6cc35f7SPeiming Liu 243d6cc35f7SPeiming Liu std::unique_ptr<SparseIterator> it = 244d6cc35f7SPeiming Liu iterSpace.extractIterator(rewriter, loc); 245d6cc35f7SPeiming Liu 246d6cc35f7SPeiming Liu SmallVector<Value> ivs; 247d6cc35f7SPeiming Liu for (ValueRange inits : adaptor.getInitArgs()) 248d6cc35f7SPeiming Liu llvm::append_range(ivs, inits); 249d6cc35f7SPeiming Liu 25071867042SPeiming Liu // Type conversion on iterate op block. 251*2b5b3cf6SMatthias Springer unsigned numOrigArgs = op.getBody()->getArgumentTypes().size(); 252*2b5b3cf6SMatthias Springer TypeConverter::SignatureConversion signatureConversion(numOrigArgs); 253d6cc35f7SPeiming Liu if (failed(typeConverter->convertSignatureArgs( 254*2b5b3cf6SMatthias Springer op.getBody()->getArgumentTypes(), signatureConversion))) 25571867042SPeiming Liu return rewriter.notifyMatchFailure( 25671867042SPeiming Liu op, "failed to convert iterate region argurment types"); 257d6cc35f7SPeiming Liu 258*2b5b3cf6SMatthias Springer Block *block = rewriter.applySignatureConversion( 259*2b5b3cf6SMatthias Springer op.getBody(), signatureConversion, getTypeConverter()); 26071867042SPeiming Liu ValueRange ret = genLoopWithIterator( 261b48ef8d8SPeiming Liu rewriter, loc, it.get(), ivs, 26271867042SPeiming Liu [block](PatternRewriter &rewriter, Location loc, Region &loopBody, 26371867042SPeiming Liu SparseIterator *it, ValueRange reduc) -> SmallVector<Value> { 264b48ef8d8SPeiming Liu SmallVector<Value> blockArgs(reduc); 26571867042SPeiming Liu // TODO: Also appends coordinates if used. 26671867042SPeiming Liu // blockArgs.push_back(it->deref(rewriter, loc)); 267b48ef8d8SPeiming Liu llvm::append_range(blockArgs, it->getCursor()); 268d6cc35f7SPeiming Liu 26971867042SPeiming Liu Block *dstBlock = &loopBody.getBlocks().front(); 27071867042SPeiming Liu rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(), 27171867042SPeiming Liu blockArgs); 27271867042SPeiming Liu auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); 27371867042SPeiming Liu // We can not use ValueRange as the operation holding the values will 27471867042SPeiming Liu // be destoryed. 27571867042SPeiming Liu SmallVector<Value> result(yield.getResults()); 27671867042SPeiming Liu rewriter.eraseOp(yield); 27771867042SPeiming Liu return result; 27871867042SPeiming Liu }); 279d6cc35f7SPeiming Liu 280*2b5b3cf6SMatthias Springer rewriter.replaceOp(op, ret); 281d6cc35f7SPeiming Liu return success(); 282d6cc35f7SPeiming Liu } 283d6cc35f7SPeiming Liu }; 284d6cc35f7SPeiming Liu 285*2b5b3cf6SMatthias Springer class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> { 286*2b5b3cf6SMatthias Springer using OpConversionPattern::OpConversionPattern; 287f607102aSPeiming Liu 288f607102aSPeiming Liu LogicalResult 289*2b5b3cf6SMatthias Springer matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor, 290*2b5b3cf6SMatthias Springer ConversionPatternRewriter &rewriter) const override { 291f607102aSPeiming Liu assert(op.getSpaceDim() == 1 && "Not implemented"); 292f607102aSPeiming Liu Location loc = op.getLoc(); 293f607102aSPeiming Liu 294f607102aSPeiming Liu I64BitSet denseBits(0); 295f607102aSPeiming Liu for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes())) 296f607102aSPeiming Liu if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT)) 297f607102aSPeiming Liu denseBits.set(idx); 298f607102aSPeiming Liu 299f607102aSPeiming Liu // If there exists a case that only contains dense spaces. I.e., case 300f607102aSPeiming Liu // bits is a subset of dense bits, or when there is a full empty case (due 301f607102aSPeiming Liu // to complements), we need a universal pointer to forward the coiteration 302f607102aSPeiming Liu // loop. 303f607102aSPeiming Liu bool needUniv = 304f607102aSPeiming Liu any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) { 305f607102aSPeiming Liu // A case for complement. 306f607102aSPeiming Liu if (caseBits.count() == 0) 307f607102aSPeiming Liu return true; 308f607102aSPeiming Liu // An all-dense case. 309f607102aSPeiming Liu return caseBits.isSubSetOf(denseBits); 310f607102aSPeiming Liu }); 311f607102aSPeiming Liu assert(!needUniv && "Not implemented"); 312f607102aSPeiming Liu (void)needUniv; 313f607102aSPeiming Liu 314*2b5b3cf6SMatthias Springer SmallVector<Block *> newBlocks; 315*2b5b3cf6SMatthias Springer DenseMap<Block *, Block *> newToOldBlockMap; 316f607102aSPeiming Liu for (Region ®ion : op.getCaseRegions()) { 317f607102aSPeiming Liu // Do a one-shot type conversion on all region blocks, since the same 318f607102aSPeiming Liu // region might be used multiple time. 319f607102aSPeiming Liu Block *block = ®ion.getBlocks().front(); 320*2b5b3cf6SMatthias Springer TypeConverter::SignatureConversion blockTypeMapping( 321*2b5b3cf6SMatthias Springer block->getArgumentTypes().size()); 322f607102aSPeiming Liu if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), 32371867042SPeiming Liu blockTypeMapping))) { 324f607102aSPeiming Liu return rewriter.notifyMatchFailure( 325f607102aSPeiming Liu op, "failed to convert coiterate region argurment types"); 32671867042SPeiming Liu } 327f607102aSPeiming Liu 328*2b5b3cf6SMatthias Springer newBlocks.push_back(rewriter.applySignatureConversion( 329*2b5b3cf6SMatthias Springer block, blockTypeMapping, getTypeConverter())); 330*2b5b3cf6SMatthias Springer newToOldBlockMap[newBlocks.back()] = block; 331f607102aSPeiming Liu } 332f607102aSPeiming Liu 333f607102aSPeiming Liu SmallVector<SparseIterationSpace> spaces; 334f607102aSPeiming Liu SmallVector<std::unique_ptr<SparseIterator>> iters; 335f607102aSPeiming Liu for (auto [spaceTp, spaceVals] : llvm::zip_equal( 336f607102aSPeiming Liu op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) { 337f607102aSPeiming Liu // TODO: do we really need tid? 338f607102aSPeiming Liu spaces.push_back(SparseIterationSpace::fromValues( 339f607102aSPeiming Liu cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0)); 340f607102aSPeiming Liu // Extract the iterator. 341f607102aSPeiming Liu iters.push_back(spaces.back().extractIterator(rewriter, loc)); 342f607102aSPeiming Liu } 343f607102aSPeiming Liu 344f607102aSPeiming Liu auto getFilteredIters = [&iters](I64BitSet caseBits) { 345f607102aSPeiming Liu // Retrives a vector of pointers to the iterators used in the case. 346f607102aSPeiming Liu SmallVector<SparseIterator *> validIters; 347f607102aSPeiming Liu for (auto idx : caseBits.bits()) 348f607102aSPeiming Liu validIters.push_back(iters[idx].get()); 349f607102aSPeiming Liu return validIters; 350f607102aSPeiming Liu }; 351f607102aSPeiming Liu 352f607102aSPeiming Liu // Get a flattened user-provided loop reduction values. 353f607102aSPeiming Liu SmallVector<Value> userReduc; 354f607102aSPeiming Liu for (ValueRange r : adaptor.getInitArgs()) 355f607102aSPeiming Liu llvm::append_range(userReduc, r); 356f607102aSPeiming Liu 357f607102aSPeiming Liu // TODO: we need to sort the cases such that they appears in lexical order. 358f607102aSPeiming Liu // Although sparsification always generates cases in that order, it might 359f607102aSPeiming Liu // not be the case for human-written code. 360f607102aSPeiming Liu 361f607102aSPeiming Liu // Generates a loop sequence, one loop per case. 362f607102aSPeiming Liu for (auto [r, caseBits] : 363*2b5b3cf6SMatthias Springer llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) { 364f607102aSPeiming Liu assert(caseBits.count() > 0 && "Complement space not implemented"); 365f607102aSPeiming Liu 366f607102aSPeiming Liu // Retrives a vector of pointers to the iterators used in the case. 367f607102aSPeiming Liu SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits); 368f607102aSPeiming Liu 369f607102aSPeiming Liu if (validIters.size() > 1) { 370f607102aSPeiming Liu auto [loop, loopCrd] = 371f607102aSPeiming Liu genCoIteration(rewriter, loc, validIters, userReduc, 372f607102aSPeiming Liu /*uniIdx=*/nullptr, /*userReducFirst=*/true); 373f607102aSPeiming Liu 374f607102aSPeiming Liu // 1st. find all the cases that is a strict subset of the current case 375f607102aSPeiming Liu // condition, for which we generate one branch per case inside the loop. 376f607102aSPeiming Liu // The subcases are never empty, it must contains at least the current 377f607102aSPeiming Liu // region itself. 378f607102aSPeiming Liu // TODO: these cases should be sorted. 379*2b5b3cf6SMatthias Springer SmallVector<Region *> subCases = 380*2b5b3cf6SMatthias Springer op.getSubCasesOf(r->getParent()->getRegionNumber()); 381*2b5b3cf6SMatthias Springer SmallVector<Block *> newBlocks, oldBlocks; 382*2b5b3cf6SMatthias Springer for (Region *r : subCases) { 383*2b5b3cf6SMatthias Springer newBlocks.push_back(&r->front()); 384*2b5b3cf6SMatthias Springer oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]); 385*2b5b3cf6SMatthias Springer } 386f607102aSPeiming Liu assert(!subCases.empty()); 387f607102aSPeiming Liu 388*2b5b3cf6SMatthias Springer ValueRange res = genCoIterateBranchNest( 389*2b5b3cf6SMatthias Springer rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc); 390f607102aSPeiming Liu 391f607102aSPeiming Liu SmallVector<Value> nextIterYields(res); 392f607102aSPeiming Liu // 2nd. foward the loop. 393f607102aSPeiming Liu for (SparseIterator *it : validIters) { 394f607102aSPeiming Liu Value cmp = rewriter.create<arith::CmpIOp>( 395f607102aSPeiming Liu loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd); 396f607102aSPeiming Liu it->forwardIf(rewriter, loc, cmp); 397f607102aSPeiming Liu llvm::append_range(nextIterYields, it->getCursor()); 398f607102aSPeiming Liu } 399f607102aSPeiming Liu rewriter.create<scf::YieldOp>(loc, nextIterYields); 400f607102aSPeiming Liu 401f607102aSPeiming Liu // Exit the loop, relink the iterator SSA value. 402f607102aSPeiming Liu rewriter.setInsertionPointAfter(loop); 403f607102aSPeiming Liu ValueRange iterVals = loop->getResults().drop_front(userReduc.size()); 404f607102aSPeiming Liu for (SparseIterator *it : validIters) 405f607102aSPeiming Liu iterVals = it->linkNewScope(iterVals); 406f607102aSPeiming Liu assert(iterVals.empty()); 407f607102aSPeiming Liu 408f607102aSPeiming Liu ValueRange curResult = loop->getResults().take_front(userReduc.size()); 409f607102aSPeiming Liu userReduc.assign(curResult.begin(), curResult.end()); 410f607102aSPeiming Liu } else { 411f607102aSPeiming Liu // This is a simple iteration loop. 412f607102aSPeiming Liu assert(caseBits.count() == 1); 413f607102aSPeiming Liu 414*2b5b3cf6SMatthias Springer Block *block = r; 415f607102aSPeiming Liu ValueRange curResult = genLoopWithIterator( 416b48ef8d8SPeiming Liu rewriter, loc, validIters.front(), userReduc, 417f607102aSPeiming Liu /*bodyBuilder=*/ 418f607102aSPeiming Liu [block](PatternRewriter &rewriter, Location loc, Region &dstRegion, 419f607102aSPeiming Liu SparseIterator *it, 420f607102aSPeiming Liu ValueRange reduc) -> SmallVector<Value> { 421f607102aSPeiming Liu SmallVector<Value> blockArgs(reduc); 422f607102aSPeiming Liu blockArgs.push_back(it->deref(rewriter, loc)); 423f607102aSPeiming Liu llvm::append_range(blockArgs, it->getCursor()); 424f607102aSPeiming Liu 425f607102aSPeiming Liu Block *dstBlock = &dstRegion.getBlocks().front(); 426f607102aSPeiming Liu rewriter.inlineBlockBefore( 427f607102aSPeiming Liu block, dstBlock, rewriter.getInsertionPoint(), blockArgs); 428f607102aSPeiming Liu auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); 429f607102aSPeiming Liu SmallVector<Value> result(yield.getResults()); 430f607102aSPeiming Liu rewriter.eraseOp(yield); 431f607102aSPeiming Liu return result; 432f607102aSPeiming Liu }); 433f607102aSPeiming Liu 434f607102aSPeiming Liu userReduc.assign(curResult.begin(), curResult.end()); 435f607102aSPeiming Liu } 436f607102aSPeiming Liu } 437f607102aSPeiming Liu 438f607102aSPeiming Liu rewriter.replaceOp(op, userReduc); 439f607102aSPeiming Liu return success(); 440f607102aSPeiming Liu } 441f607102aSPeiming Liu }; 442f607102aSPeiming Liu 443c42bbda4SPeiming Liu } // namespace 444c42bbda4SPeiming Liu 445c42bbda4SPeiming Liu mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { 446c42bbda4SPeiming Liu addConversion([](Type type) { return type; }); 447d6cc35f7SPeiming Liu addConversion(convertIteratorType); 448c42bbda4SPeiming Liu addConversion(convertIterSpaceType); 449c42bbda4SPeiming Liu 450c42bbda4SPeiming Liu addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp, 451f18c3e4eSMatthias Springer ValueRange inputs, Location loc) -> Value { 452c42bbda4SPeiming Liu return builder 453c42bbda4SPeiming Liu .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs) 454c42bbda4SPeiming Liu .getResult(0); 455c42bbda4SPeiming Liu }); 456c42bbda4SPeiming Liu } 457c42bbda4SPeiming Liu 458c42bbda4SPeiming Liu void mlir::populateLowerSparseIterationToSCFPatterns( 459206fad0eSMatthias Springer const TypeConverter &converter, RewritePatternSet &patterns) { 460a02010b3SPeiming Liu 461a02010b3SPeiming Liu IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext()); 462951a3630SPeiming Liu patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter, 463f607102aSPeiming Liu SparseIterateOpConverter, SparseCoIterateOpConverter>( 464f607102aSPeiming Liu converter, patterns.getContext()); 465c42bbda4SPeiming Liu } 466