xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (revision 2b5b3cf60d9e9e0c597bad1be1207b167ef15c9f)
1c42bbda4SPeiming Liu 
2c42bbda4SPeiming Liu #include "Utils/CodegenUtils.h"
3f607102aSPeiming Liu #include "Utils/LoopEmitter.h"
4c42bbda4SPeiming Liu #include "Utils/SparseTensorIterator.h"
5c42bbda4SPeiming Liu 
6951a3630SPeiming Liu #include "mlir/Dialect/MemRef/IR/MemRef.h"
7c42bbda4SPeiming Liu #include "mlir/Dialect/SCF/IR/SCF.h"
8c42bbda4SPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
9c42bbda4SPeiming Liu #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10*2b5b3cf6SMatthias Springer #include "mlir/Transforms/DialectConversion.h"
11c42bbda4SPeiming Liu 
12c42bbda4SPeiming Liu using namespace mlir;
13c42bbda4SPeiming Liu using namespace mlir::sparse_tensor;
14c42bbda4SPeiming Liu 
15*2b5b3cf6SMatthias Springer /// Assert that the given value range contains a single value and return it.
16*2b5b3cf6SMatthias Springer static Value getSingleValue(ValueRange values) {
17*2b5b3cf6SMatthias Springer   assert(values.size() == 1 && "expected single value");
18*2b5b3cf6SMatthias Springer   return values.front();
19*2b5b3cf6SMatthias Springer }
20*2b5b3cf6SMatthias Springer 
21951a3630SPeiming Liu static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
22c42bbda4SPeiming Liu                              SmallVectorImpl<Type> &fields) {
23c42bbda4SPeiming Liu   // Position and coordinate buffer in the sparse structure.
24c42bbda4SPeiming Liu   if (enc.getLvlType(lvl).isWithPosLT())
25c42bbda4SPeiming Liu     fields.push_back(enc.getPosMemRefType());
26c42bbda4SPeiming Liu   if (enc.getLvlType(lvl).isWithCrdLT())
27c42bbda4SPeiming Liu     fields.push_back(enc.getCrdMemRefType());
28c42bbda4SPeiming Liu   // One index for shape bound (result from lvlOp).
29c42bbda4SPeiming Liu   fields.push_back(IndexType::get(enc.getContext()));
30c42bbda4SPeiming Liu }
31c42bbda4SPeiming Liu 
32c42bbda4SPeiming Liu static std::optional<LogicalResult>
33c42bbda4SPeiming Liu convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
34c42bbda4SPeiming Liu 
35c42bbda4SPeiming Liu   auto idxTp = IndexType::get(itSp.getContext());
36c42bbda4SPeiming Liu   for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
37c42bbda4SPeiming Liu     convertLevelType(itSp.getEncoding(), l, fields);
38c42bbda4SPeiming Liu 
39c42bbda4SPeiming Liu   // Two indices for lower and upper bound (we only need one pair for the last
40c42bbda4SPeiming Liu   // iteration space).
41c42bbda4SPeiming Liu   fields.append({idxTp, idxTp});
42c42bbda4SPeiming Liu   return success();
43c42bbda4SPeiming Liu }
44c42bbda4SPeiming Liu 
45d6cc35f7SPeiming Liu static std::optional<LogicalResult>
46d6cc35f7SPeiming Liu convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
47d6cc35f7SPeiming Liu   // The actually Iterator Values (that are updated every iteration).
48d6cc35f7SPeiming Liu   auto idxTp = IndexType::get(itTp.getContext());
49d6cc35f7SPeiming Liu   // TODO: handle batch dimension.
50d6cc35f7SPeiming Liu   assert(itTp.getEncoding().getBatchLvlRank() == 0);
51d6cc35f7SPeiming Liu   if (!itTp.isUnique()) {
52d6cc35f7SPeiming Liu     // Segment high for non-unique iterator.
53d6cc35f7SPeiming Liu     fields.push_back(idxTp);
54d6cc35f7SPeiming Liu   }
55d6cc35f7SPeiming Liu   fields.push_back(idxTp);
56d6cc35f7SPeiming Liu   return success();
57d6cc35f7SPeiming Liu }
58d6cc35f7SPeiming Liu 
59f607102aSPeiming Liu static ValueRange
60f607102aSPeiming Liu genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
61f607102aSPeiming Liu                        Value loopCrd,
62f607102aSPeiming Liu                        ArrayRef<std::unique_ptr<SparseIterator>> iters,
63*2b5b3cf6SMatthias Springer                        ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
64*2b5b3cf6SMatthias Springer                        ArrayRef<Value> userReduc) {
65*2b5b3cf6SMatthias Springer   if (newBlocks.empty())
66f607102aSPeiming Liu     return userReduc;
67f607102aSPeiming Liu 
68f607102aSPeiming Liu   // The current branch that we are handling.
69*2b5b3cf6SMatthias Springer   Block *newBlock = newBlocks.front();
70*2b5b3cf6SMatthias Springer   Block *oldBlock = oldBlocks.front();
71f607102aSPeiming Liu   Value casePred = constantI1(rewriter, loc, true);
72*2b5b3cf6SMatthias Springer   I64BitSet caseBits =
73*2b5b3cf6SMatthias Springer       op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
74f607102aSPeiming Liu   for (unsigned i : caseBits.bits()) {
75f607102aSPeiming Liu     SparseIterator *it = iters[i].get();
76f607102aSPeiming Liu     Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
77f607102aSPeiming Liu                                                 it->getCrd(), loopCrd);
78f607102aSPeiming Liu     casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
79f607102aSPeiming Liu   }
80f607102aSPeiming Liu   scf::IfOp ifOp = rewriter.create<scf::IfOp>(
81f607102aSPeiming Liu       loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
82f607102aSPeiming Liu   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
83f607102aSPeiming Liu 
84f607102aSPeiming Liu   // Erase the empty block.
85f607102aSPeiming Liu   rewriter.eraseBlock(&ifOp.getThenRegion().front());
86f607102aSPeiming Liu   // Set up block arguments: user-provided values -> loop coord -> iterators.
87f607102aSPeiming Liu   SmallVector<Value> blockArgs(userReduc);
88f607102aSPeiming Liu   blockArgs.push_back(loopCrd);
89f607102aSPeiming Liu   for (unsigned idx : caseBits.bits())
90f607102aSPeiming Liu     llvm::append_range(blockArgs, iters[idx]->getCursor());
91f607102aSPeiming Liu 
92*2b5b3cf6SMatthias Springer   // Map the old block arguments, because the dialect conversion driver does
93*2b5b3cf6SMatthias Springer   // not immediately perform SSA value replacements. This function is still
94*2b5b3cf6SMatthias Springer   // seeing the old uses.
95f607102aSPeiming Liu   IRMapping mapping;
96*2b5b3cf6SMatthias Springer   for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
97f607102aSPeiming Liu     mapping.map(from, to);
98f607102aSPeiming Liu   }
99f607102aSPeiming Liu 
100f607102aSPeiming Liu   // Clone the region, we can not erase the region now because the same region
101f607102aSPeiming Liu   // might be a subcase for multiple lattice point.
102*2b5b3cf6SMatthias Springer   rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
103f607102aSPeiming Liu                              ifOp.getThenRegion().begin(), mapping);
104*2b5b3cf6SMatthias Springer   // Remove the block arguments, they were already replaced via `mapping`.
105*2b5b3cf6SMatthias Springer   ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
106f607102aSPeiming Liu 
107f607102aSPeiming Liu   // replace sparse_tensor::YieldOp -> scf::YieldOp
108f607102aSPeiming Liu   auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
109f607102aSPeiming Liu   ValueRange yields = spY.getResults();
110f607102aSPeiming Liu   rewriter.eraseOp(spY);
111f607102aSPeiming Liu   rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
112f607102aSPeiming Liu   rewriter.create<scf::YieldOp>(loc, yields);
113f607102aSPeiming Liu 
114f607102aSPeiming Liu   // Generates remaining case recursively.
115f607102aSPeiming Liu   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
116f607102aSPeiming Liu   ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
117*2b5b3cf6SMatthias Springer                                           newBlocks.drop_front(),
118*2b5b3cf6SMatthias Springer                                           oldBlocks.drop_front(), userReduc);
119f607102aSPeiming Liu   if (!res.empty())
120f607102aSPeiming Liu     rewriter.create<scf::YieldOp>(loc, res);
121f607102aSPeiming Liu 
122f607102aSPeiming Liu   rewriter.setInsertionPointAfter(ifOp);
123f607102aSPeiming Liu   return ifOp.getResults();
124f607102aSPeiming Liu }
125f607102aSPeiming Liu 
126f607102aSPeiming Liu static ValueRange genLoopWithIterator(
127f607102aSPeiming Liu     PatternRewriter &rewriter, Location loc, SparseIterator *it,
128b48ef8d8SPeiming Liu     ValueRange reduc,
129f607102aSPeiming Liu     function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
130f607102aSPeiming Liu                                     Region &loopBody, SparseIterator *it,
131f607102aSPeiming Liu                                     ValueRange reduc)>
132f607102aSPeiming Liu         bodyBuilder) {
133f607102aSPeiming Liu   if (it->iteratableByFor()) {
134f607102aSPeiming Liu     auto [lo, hi] = it->genForCond(rewriter, loc);
135f607102aSPeiming Liu     Value step = constantIndex(rewriter, loc, 1);
136*2b5b3cf6SMatthias Springer     scf::ForOp forOp = rewriter.create<scf::ForOp>(
137*2b5b3cf6SMatthias Springer         loc, lo, hi, step, reduc,
138*2b5b3cf6SMatthias Springer         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
139*2b5b3cf6SMatthias Springer           // Empty builder function to ensure that no terminator is created.
140*2b5b3cf6SMatthias Springer         });
141f607102aSPeiming Liu     {
142f607102aSPeiming Liu       OpBuilder::InsertionGuard guard(rewriter);
143f607102aSPeiming Liu       it->linkNewScope(forOp.getInductionVar());
144f607102aSPeiming Liu       rewriter.setInsertionPointToStart(forOp.getBody());
145f607102aSPeiming Liu       SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
146f607102aSPeiming Liu                                            it, forOp.getRegionIterArgs());
147f607102aSPeiming Liu 
148f607102aSPeiming Liu       rewriter.setInsertionPointToEnd(forOp.getBody());
149f607102aSPeiming Liu       rewriter.create<scf::YieldOp>(loc, ret);
150f607102aSPeiming Liu     }
151f607102aSPeiming Liu     return forOp.getResults();
152f607102aSPeiming Liu   }
153b48ef8d8SPeiming Liu 
154b48ef8d8SPeiming Liu   SmallVector<Value> ivs(reduc);
155f607102aSPeiming Liu   llvm::append_range(ivs, it->getCursor());
156f607102aSPeiming Liu 
157f607102aSPeiming Liu   TypeRange types = ValueRange(ivs).getTypes();
158f607102aSPeiming Liu   auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
159f607102aSPeiming Liu   {
160f607102aSPeiming Liu     OpBuilder::InsertionGuard guard(rewriter);
161f607102aSPeiming Liu     // Generates loop conditions.
162f607102aSPeiming Liu     SmallVector<Location> l(types.size(), loc);
163f607102aSPeiming Liu     Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
164f607102aSPeiming Liu     rewriter.setInsertionPointToStart(before);
165f607102aSPeiming Liu     ValueRange bArgs = before->getArguments();
166f607102aSPeiming Liu     auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
167f607102aSPeiming Liu     rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
168f607102aSPeiming Liu 
169f607102aSPeiming Liu     // Delegates loop body generation.
170f607102aSPeiming Liu     Region &dstRegion = whileOp.getAfter();
171f607102aSPeiming Liu     Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
172f607102aSPeiming Liu     ValueRange aArgs = whileOp.getAfterArguments();
173f607102aSPeiming Liu     it->linkNewScope(aArgs.drop_front(reduc.size()));
174b48ef8d8SPeiming Liu     aArgs = aArgs.take_front(reduc.size());
175f607102aSPeiming Liu 
176f607102aSPeiming Liu     rewriter.setInsertionPointToStart(after);
177f607102aSPeiming Liu     SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
178f607102aSPeiming Liu     rewriter.setInsertionPointToEnd(after);
179f607102aSPeiming Liu 
180f607102aSPeiming Liu     // Forward loops
181f607102aSPeiming Liu     SmallVector<Value> yields;
182f607102aSPeiming Liu     llvm::append_range(yields, ret);
183b48ef8d8SPeiming Liu     llvm::append_range(yields, it->forward(rewriter, loc));
184f607102aSPeiming Liu     rewriter.create<scf::YieldOp>(loc, yields);
185f607102aSPeiming Liu   }
186f607102aSPeiming Liu   return whileOp.getResults().drop_front(it->getCursor().size());
187f607102aSPeiming Liu }
188f607102aSPeiming Liu 
189c42bbda4SPeiming Liu namespace {
190c42bbda4SPeiming Liu 
191c42bbda4SPeiming Liu /// Sparse codegen rule for number of entries operator.
192c42bbda4SPeiming Liu class ExtractIterSpaceConverter
193*2b5b3cf6SMatthias Springer     : public OpConversionPattern<ExtractIterSpaceOp> {
194c42bbda4SPeiming Liu public:
195*2b5b3cf6SMatthias Springer   using OpConversionPattern::OpConversionPattern;
196c42bbda4SPeiming Liu   LogicalResult
197*2b5b3cf6SMatthias Springer   matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
198*2b5b3cf6SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
199c42bbda4SPeiming Liu     Location loc = op.getLoc();
200c42bbda4SPeiming Liu 
201c42bbda4SPeiming Liu     // Construct the iteration space.
202*2b5b3cf6SMatthias Springer     SparseIterationSpace space(loc, rewriter,
203*2b5b3cf6SMatthias Springer                                getSingleValue(adaptor.getTensor()), 0,
204c42bbda4SPeiming Liu                                op.getLvlRange(), adaptor.getParentIter());
205c42bbda4SPeiming Liu 
206c42bbda4SPeiming Liu     SmallVector<Value> result = space.toValues();
207*2b5b3cf6SMatthias Springer     rewriter.replaceOpWithMultiple(op, {result});
208c42bbda4SPeiming Liu     return success();
209c42bbda4SPeiming Liu   }
210c42bbda4SPeiming Liu };
211c42bbda4SPeiming Liu 
212951a3630SPeiming Liu /// Sparse codegen rule for number of entries operator.
213*2b5b3cf6SMatthias Springer class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
214951a3630SPeiming Liu public:
215*2b5b3cf6SMatthias Springer   using OpConversionPattern::OpConversionPattern;
216951a3630SPeiming Liu   LogicalResult
217*2b5b3cf6SMatthias Springer   matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
218*2b5b3cf6SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
219951a3630SPeiming Liu     Location loc = op.getLoc();
220951a3630SPeiming Liu     Value pos = adaptor.getIterator().back();
221*2b5b3cf6SMatthias Springer     Value valBuf =
222*2b5b3cf6SMatthias Springer         rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
223951a3630SPeiming Liu     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
224951a3630SPeiming Liu     return success();
225951a3630SPeiming Liu   }
226951a3630SPeiming Liu };
227951a3630SPeiming Liu 
228*2b5b3cf6SMatthias Springer class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
229d6cc35f7SPeiming Liu public:
230*2b5b3cf6SMatthias Springer   using OpConversionPattern::OpConversionPattern;
231d6cc35f7SPeiming Liu   LogicalResult
232*2b5b3cf6SMatthias Springer   matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
233*2b5b3cf6SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
234d6cc35f7SPeiming Liu     if (!op.getCrdUsedLvls().empty())
235d6cc35f7SPeiming Liu       return rewriter.notifyMatchFailure(
236d6cc35f7SPeiming Liu           op, "non-empty coordinates list not implemented.");
237d6cc35f7SPeiming Liu 
238d6cc35f7SPeiming Liu     Location loc = op.getLoc();
239d6cc35f7SPeiming Liu 
240d6cc35f7SPeiming Liu     auto iterSpace = SparseIterationSpace::fromValues(
241d6cc35f7SPeiming Liu         op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
242d6cc35f7SPeiming Liu 
243d6cc35f7SPeiming Liu     std::unique_ptr<SparseIterator> it =
244d6cc35f7SPeiming Liu         iterSpace.extractIterator(rewriter, loc);
245d6cc35f7SPeiming Liu 
246d6cc35f7SPeiming Liu     SmallVector<Value> ivs;
247d6cc35f7SPeiming Liu     for (ValueRange inits : adaptor.getInitArgs())
248d6cc35f7SPeiming Liu       llvm::append_range(ivs, inits);
249d6cc35f7SPeiming Liu 
25071867042SPeiming Liu     // Type conversion on iterate op block.
251*2b5b3cf6SMatthias Springer     unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
252*2b5b3cf6SMatthias Springer     TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
253d6cc35f7SPeiming Liu     if (failed(typeConverter->convertSignatureArgs(
254*2b5b3cf6SMatthias Springer             op.getBody()->getArgumentTypes(), signatureConversion)))
25571867042SPeiming Liu       return rewriter.notifyMatchFailure(
25671867042SPeiming Liu           op, "failed to convert iterate region argurment types");
257d6cc35f7SPeiming Liu 
258*2b5b3cf6SMatthias Springer     Block *block = rewriter.applySignatureConversion(
259*2b5b3cf6SMatthias Springer         op.getBody(), signatureConversion, getTypeConverter());
26071867042SPeiming Liu     ValueRange ret = genLoopWithIterator(
261b48ef8d8SPeiming Liu         rewriter, loc, it.get(), ivs,
26271867042SPeiming Liu         [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
26371867042SPeiming Liu                 SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264b48ef8d8SPeiming Liu           SmallVector<Value> blockArgs(reduc);
26571867042SPeiming Liu           // TODO: Also appends coordinates if used.
26671867042SPeiming Liu           // blockArgs.push_back(it->deref(rewriter, loc));
267b48ef8d8SPeiming Liu           llvm::append_range(blockArgs, it->getCursor());
268d6cc35f7SPeiming Liu 
26971867042SPeiming Liu           Block *dstBlock = &loopBody.getBlocks().front();
27071867042SPeiming Liu           rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
27171867042SPeiming Liu                                      blockArgs);
27271867042SPeiming Liu           auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
27371867042SPeiming Liu           // We can not use ValueRange as the operation holding the values will
27471867042SPeiming Liu           // be destoryed.
27571867042SPeiming Liu           SmallVector<Value> result(yield.getResults());
27671867042SPeiming Liu           rewriter.eraseOp(yield);
27771867042SPeiming Liu           return result;
27871867042SPeiming Liu         });
279d6cc35f7SPeiming Liu 
280*2b5b3cf6SMatthias Springer     rewriter.replaceOp(op, ret);
281d6cc35f7SPeiming Liu     return success();
282d6cc35f7SPeiming Liu   }
283d6cc35f7SPeiming Liu };
284d6cc35f7SPeiming Liu 
285*2b5b3cf6SMatthias Springer class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
286*2b5b3cf6SMatthias Springer   using OpConversionPattern::OpConversionPattern;
287f607102aSPeiming Liu 
288f607102aSPeiming Liu   LogicalResult
289*2b5b3cf6SMatthias Springer   matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
290*2b5b3cf6SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
291f607102aSPeiming Liu     assert(op.getSpaceDim() == 1 && "Not implemented");
292f607102aSPeiming Liu     Location loc = op.getLoc();
293f607102aSPeiming Liu 
294f607102aSPeiming Liu     I64BitSet denseBits(0);
295f607102aSPeiming Liu     for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
296f607102aSPeiming Liu       if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
297f607102aSPeiming Liu         denseBits.set(idx);
298f607102aSPeiming Liu 
299f607102aSPeiming Liu     // If there exists a case that only contains dense spaces. I.e., case
300f607102aSPeiming Liu     // bits is a subset of dense bits, or when there is a full empty case (due
301f607102aSPeiming Liu     // to complements), we need a universal pointer to forward the coiteration
302f607102aSPeiming Liu     // loop.
303f607102aSPeiming Liu     bool needUniv =
304f607102aSPeiming Liu         any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
305f607102aSPeiming Liu           // A case for complement.
306f607102aSPeiming Liu           if (caseBits.count() == 0)
307f607102aSPeiming Liu             return true;
308f607102aSPeiming Liu           // An all-dense case.
309f607102aSPeiming Liu           return caseBits.isSubSetOf(denseBits);
310f607102aSPeiming Liu         });
311f607102aSPeiming Liu     assert(!needUniv && "Not implemented");
312f607102aSPeiming Liu     (void)needUniv;
313f607102aSPeiming Liu 
314*2b5b3cf6SMatthias Springer     SmallVector<Block *> newBlocks;
315*2b5b3cf6SMatthias Springer     DenseMap<Block *, Block *> newToOldBlockMap;
316f607102aSPeiming Liu     for (Region &region : op.getCaseRegions()) {
317f607102aSPeiming Liu       // Do a one-shot type conversion on all region blocks, since the same
318f607102aSPeiming Liu       // region might be used multiple time.
319f607102aSPeiming Liu       Block *block = &region.getBlocks().front();
320*2b5b3cf6SMatthias Springer       TypeConverter::SignatureConversion blockTypeMapping(
321*2b5b3cf6SMatthias Springer           block->getArgumentTypes().size());
322f607102aSPeiming Liu       if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
32371867042SPeiming Liu                                                      blockTypeMapping))) {
324f607102aSPeiming Liu         return rewriter.notifyMatchFailure(
325f607102aSPeiming Liu             op, "failed to convert coiterate region argurment types");
32671867042SPeiming Liu       }
327f607102aSPeiming Liu 
328*2b5b3cf6SMatthias Springer       newBlocks.push_back(rewriter.applySignatureConversion(
329*2b5b3cf6SMatthias Springer           block, blockTypeMapping, getTypeConverter()));
330*2b5b3cf6SMatthias Springer       newToOldBlockMap[newBlocks.back()] = block;
331f607102aSPeiming Liu     }
332f607102aSPeiming Liu 
333f607102aSPeiming Liu     SmallVector<SparseIterationSpace> spaces;
334f607102aSPeiming Liu     SmallVector<std::unique_ptr<SparseIterator>> iters;
335f607102aSPeiming Liu     for (auto [spaceTp, spaceVals] : llvm::zip_equal(
336f607102aSPeiming Liu              op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
337f607102aSPeiming Liu       // TODO: do we really need tid?
338f607102aSPeiming Liu       spaces.push_back(SparseIterationSpace::fromValues(
339f607102aSPeiming Liu           cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
340f607102aSPeiming Liu       // Extract the iterator.
341f607102aSPeiming Liu       iters.push_back(spaces.back().extractIterator(rewriter, loc));
342f607102aSPeiming Liu     }
343f607102aSPeiming Liu 
344f607102aSPeiming Liu     auto getFilteredIters = [&iters](I64BitSet caseBits) {
345f607102aSPeiming Liu       // Retrives a vector of pointers to the iterators used in the case.
346f607102aSPeiming Liu       SmallVector<SparseIterator *> validIters;
347f607102aSPeiming Liu       for (auto idx : caseBits.bits())
348f607102aSPeiming Liu         validIters.push_back(iters[idx].get());
349f607102aSPeiming Liu       return validIters;
350f607102aSPeiming Liu     };
351f607102aSPeiming Liu 
352f607102aSPeiming Liu     // Get a flattened user-provided loop reduction values.
353f607102aSPeiming Liu     SmallVector<Value> userReduc;
354f607102aSPeiming Liu     for (ValueRange r : adaptor.getInitArgs())
355f607102aSPeiming Liu       llvm::append_range(userReduc, r);
356f607102aSPeiming Liu 
357f607102aSPeiming Liu     // TODO: we need to sort the cases such that they appears in lexical order.
358f607102aSPeiming Liu     // Although sparsification always generates cases in that order, it might
359f607102aSPeiming Liu     // not be the case for human-written code.
360f607102aSPeiming Liu 
361f607102aSPeiming Liu     // Generates a loop sequence, one loop per case.
362f607102aSPeiming Liu     for (auto [r, caseBits] :
363*2b5b3cf6SMatthias Springer          llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
364f607102aSPeiming Liu       assert(caseBits.count() > 0 && "Complement space not implemented");
365f607102aSPeiming Liu 
366f607102aSPeiming Liu       // Retrives a vector of pointers to the iterators used in the case.
367f607102aSPeiming Liu       SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
368f607102aSPeiming Liu 
369f607102aSPeiming Liu       if (validIters.size() > 1) {
370f607102aSPeiming Liu         auto [loop, loopCrd] =
371f607102aSPeiming Liu             genCoIteration(rewriter, loc, validIters, userReduc,
372f607102aSPeiming Liu                            /*uniIdx=*/nullptr, /*userReducFirst=*/true);
373f607102aSPeiming Liu 
374f607102aSPeiming Liu         // 1st. find all the cases that is a strict subset of the current case
375f607102aSPeiming Liu         // condition, for which we generate one branch per case inside the loop.
376f607102aSPeiming Liu         // The subcases are never empty, it must contains at least the current
377f607102aSPeiming Liu         // region itself.
378f607102aSPeiming Liu         // TODO: these cases should be sorted.
379*2b5b3cf6SMatthias Springer         SmallVector<Region *> subCases =
380*2b5b3cf6SMatthias Springer             op.getSubCasesOf(r->getParent()->getRegionNumber());
381*2b5b3cf6SMatthias Springer         SmallVector<Block *> newBlocks, oldBlocks;
382*2b5b3cf6SMatthias Springer         for (Region *r : subCases) {
383*2b5b3cf6SMatthias Springer           newBlocks.push_back(&r->front());
384*2b5b3cf6SMatthias Springer           oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
385*2b5b3cf6SMatthias Springer         }
386f607102aSPeiming Liu         assert(!subCases.empty());
387f607102aSPeiming Liu 
388*2b5b3cf6SMatthias Springer         ValueRange res = genCoIterateBranchNest(
389*2b5b3cf6SMatthias Springer             rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
390f607102aSPeiming Liu 
391f607102aSPeiming Liu         SmallVector<Value> nextIterYields(res);
392f607102aSPeiming Liu         // 2nd. foward the loop.
393f607102aSPeiming Liu         for (SparseIterator *it : validIters) {
394f607102aSPeiming Liu           Value cmp = rewriter.create<arith::CmpIOp>(
395f607102aSPeiming Liu               loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
396f607102aSPeiming Liu           it->forwardIf(rewriter, loc, cmp);
397f607102aSPeiming Liu           llvm::append_range(nextIterYields, it->getCursor());
398f607102aSPeiming Liu         }
399f607102aSPeiming Liu         rewriter.create<scf::YieldOp>(loc, nextIterYields);
400f607102aSPeiming Liu 
401f607102aSPeiming Liu         // Exit the loop, relink the iterator SSA value.
402f607102aSPeiming Liu         rewriter.setInsertionPointAfter(loop);
403f607102aSPeiming Liu         ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
404f607102aSPeiming Liu         for (SparseIterator *it : validIters)
405f607102aSPeiming Liu           iterVals = it->linkNewScope(iterVals);
406f607102aSPeiming Liu         assert(iterVals.empty());
407f607102aSPeiming Liu 
408f607102aSPeiming Liu         ValueRange curResult = loop->getResults().take_front(userReduc.size());
409f607102aSPeiming Liu         userReduc.assign(curResult.begin(), curResult.end());
410f607102aSPeiming Liu       } else {
411f607102aSPeiming Liu         // This is a simple iteration loop.
412f607102aSPeiming Liu         assert(caseBits.count() == 1);
413f607102aSPeiming Liu 
414*2b5b3cf6SMatthias Springer         Block *block = r;
415f607102aSPeiming Liu         ValueRange curResult = genLoopWithIterator(
416b48ef8d8SPeiming Liu             rewriter, loc, validIters.front(), userReduc,
417f607102aSPeiming Liu             /*bodyBuilder=*/
418f607102aSPeiming Liu             [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
419f607102aSPeiming Liu                     SparseIterator *it,
420f607102aSPeiming Liu                     ValueRange reduc) -> SmallVector<Value> {
421f607102aSPeiming Liu               SmallVector<Value> blockArgs(reduc);
422f607102aSPeiming Liu               blockArgs.push_back(it->deref(rewriter, loc));
423f607102aSPeiming Liu               llvm::append_range(blockArgs, it->getCursor());
424f607102aSPeiming Liu 
425f607102aSPeiming Liu               Block *dstBlock = &dstRegion.getBlocks().front();
426f607102aSPeiming Liu               rewriter.inlineBlockBefore(
427f607102aSPeiming Liu                   block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
428f607102aSPeiming Liu               auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
429f607102aSPeiming Liu               SmallVector<Value> result(yield.getResults());
430f607102aSPeiming Liu               rewriter.eraseOp(yield);
431f607102aSPeiming Liu               return result;
432f607102aSPeiming Liu             });
433f607102aSPeiming Liu 
434f607102aSPeiming Liu         userReduc.assign(curResult.begin(), curResult.end());
435f607102aSPeiming Liu       }
436f607102aSPeiming Liu     }
437f607102aSPeiming Liu 
438f607102aSPeiming Liu     rewriter.replaceOp(op, userReduc);
439f607102aSPeiming Liu     return success();
440f607102aSPeiming Liu   }
441f607102aSPeiming Liu };
442f607102aSPeiming Liu 
443c42bbda4SPeiming Liu } // namespace
444c42bbda4SPeiming Liu 
445c42bbda4SPeiming Liu mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
446c42bbda4SPeiming Liu   addConversion([](Type type) { return type; });
447d6cc35f7SPeiming Liu   addConversion(convertIteratorType);
448c42bbda4SPeiming Liu   addConversion(convertIterSpaceType);
449c42bbda4SPeiming Liu 
450c42bbda4SPeiming Liu   addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
451f18c3e4eSMatthias Springer                               ValueRange inputs, Location loc) -> Value {
452c42bbda4SPeiming Liu     return builder
453c42bbda4SPeiming Liu         .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
454c42bbda4SPeiming Liu         .getResult(0);
455c42bbda4SPeiming Liu   });
456c42bbda4SPeiming Liu }
457c42bbda4SPeiming Liu 
458c42bbda4SPeiming Liu void mlir::populateLowerSparseIterationToSCFPatterns(
459206fad0eSMatthias Springer     const TypeConverter &converter, RewritePatternSet &patterns) {
460a02010b3SPeiming Liu 
461a02010b3SPeiming Liu   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
462951a3630SPeiming Liu   patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
463f607102aSPeiming Liu                SparseIterateOpConverter, SparseCoIterateOpConverter>(
464f607102aSPeiming Liu       converter, patterns.getContext());
465c42bbda4SPeiming Liu }
466