xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (revision 2b5b3cf60d9e9e0c597bad1be1207b167ef15c9f)
1 
2 #include "Utils/CodegenUtils.h"
3 #include "Utils/LoopEmitter.h"
4 #include "Utils/SparseTensorIterator.h"
5 
6 #include "mlir/Dialect/MemRef/IR/MemRef.h"
7 #include "mlir/Dialect/SCF/IR/SCF.h"
8 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
9 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10 #include "mlir/Transforms/DialectConversion.h"
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 
15 /// Assert that the given value range contains a single value and return it.
16 static Value getSingleValue(ValueRange values) {
17   assert(values.size() == 1 && "expected single value");
18   return values.front();
19 }
20 
21 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
22                              SmallVectorImpl<Type> &fields) {
23   // Position and coordinate buffer in the sparse structure.
24   if (enc.getLvlType(lvl).isWithPosLT())
25     fields.push_back(enc.getPosMemRefType());
26   if (enc.getLvlType(lvl).isWithCrdLT())
27     fields.push_back(enc.getCrdMemRefType());
28   // One index for shape bound (result from lvlOp).
29   fields.push_back(IndexType::get(enc.getContext()));
30 }
31 
32 static std::optional<LogicalResult>
33 convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
34 
35   auto idxTp = IndexType::get(itSp.getContext());
36   for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
37     convertLevelType(itSp.getEncoding(), l, fields);
38 
39   // Two indices for lower and upper bound (we only need one pair for the last
40   // iteration space).
41   fields.append({idxTp, idxTp});
42   return success();
43 }
44 
45 static std::optional<LogicalResult>
46 convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
47   // The actually Iterator Values (that are updated every iteration).
48   auto idxTp = IndexType::get(itTp.getContext());
49   // TODO: handle batch dimension.
50   assert(itTp.getEncoding().getBatchLvlRank() == 0);
51   if (!itTp.isUnique()) {
52     // Segment high for non-unique iterator.
53     fields.push_back(idxTp);
54   }
55   fields.push_back(idxTp);
56   return success();
57 }
58 
59 static ValueRange
60 genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
61                        Value loopCrd,
62                        ArrayRef<std::unique_ptr<SparseIterator>> iters,
63                        ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
64                        ArrayRef<Value> userReduc) {
65   if (newBlocks.empty())
66     return userReduc;
67 
68   // The current branch that we are handling.
69   Block *newBlock = newBlocks.front();
70   Block *oldBlock = oldBlocks.front();
71   Value casePred = constantI1(rewriter, loc, true);
72   I64BitSet caseBits =
73       op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
74   for (unsigned i : caseBits.bits()) {
75     SparseIterator *it = iters[i].get();
76     Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
77                                                 it->getCrd(), loopCrd);
78     casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
79   }
80   scf::IfOp ifOp = rewriter.create<scf::IfOp>(
81       loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
82   rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
83 
84   // Erase the empty block.
85   rewriter.eraseBlock(&ifOp.getThenRegion().front());
86   // Set up block arguments: user-provided values -> loop coord -> iterators.
87   SmallVector<Value> blockArgs(userReduc);
88   blockArgs.push_back(loopCrd);
89   for (unsigned idx : caseBits.bits())
90     llvm::append_range(blockArgs, iters[idx]->getCursor());
91 
92   // Map the old block arguments, because the dialect conversion driver does
93   // not immediately perform SSA value replacements. This function is still
94   // seeing the old uses.
95   IRMapping mapping;
96   for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
97     mapping.map(from, to);
98   }
99 
100   // Clone the region, we can not erase the region now because the same region
101   // might be a subcase for multiple lattice point.
102   rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
103                              ifOp.getThenRegion().begin(), mapping);
104   // Remove the block arguments, they were already replaced via `mapping`.
105   ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
106 
107   // replace sparse_tensor::YieldOp -> scf::YieldOp
108   auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
109   ValueRange yields = spY.getResults();
110   rewriter.eraseOp(spY);
111   rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
112   rewriter.create<scf::YieldOp>(loc, yields);
113 
114   // Generates remaining case recursively.
115   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
116   ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
117                                           newBlocks.drop_front(),
118                                           oldBlocks.drop_front(), userReduc);
119   if (!res.empty())
120     rewriter.create<scf::YieldOp>(loc, res);
121 
122   rewriter.setInsertionPointAfter(ifOp);
123   return ifOp.getResults();
124 }
125 
126 static ValueRange genLoopWithIterator(
127     PatternRewriter &rewriter, Location loc, SparseIterator *it,
128     ValueRange reduc,
129     function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
130                                     Region &loopBody, SparseIterator *it,
131                                     ValueRange reduc)>
132         bodyBuilder) {
133   if (it->iteratableByFor()) {
134     auto [lo, hi] = it->genForCond(rewriter, loc);
135     Value step = constantIndex(rewriter, loc, 1);
136     scf::ForOp forOp = rewriter.create<scf::ForOp>(
137         loc, lo, hi, step, reduc,
138         [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
139           // Empty builder function to ensure that no terminator is created.
140         });
141     {
142       OpBuilder::InsertionGuard guard(rewriter);
143       it->linkNewScope(forOp.getInductionVar());
144       rewriter.setInsertionPointToStart(forOp.getBody());
145       SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
146                                            it, forOp.getRegionIterArgs());
147 
148       rewriter.setInsertionPointToEnd(forOp.getBody());
149       rewriter.create<scf::YieldOp>(loc, ret);
150     }
151     return forOp.getResults();
152   }
153 
154   SmallVector<Value> ivs(reduc);
155   llvm::append_range(ivs, it->getCursor());
156 
157   TypeRange types = ValueRange(ivs).getTypes();
158   auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
159   {
160     OpBuilder::InsertionGuard guard(rewriter);
161     // Generates loop conditions.
162     SmallVector<Location> l(types.size(), loc);
163     Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
164     rewriter.setInsertionPointToStart(before);
165     ValueRange bArgs = before->getArguments();
166     auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
167     rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
168 
169     // Delegates loop body generation.
170     Region &dstRegion = whileOp.getAfter();
171     Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
172     ValueRange aArgs = whileOp.getAfterArguments();
173     it->linkNewScope(aArgs.drop_front(reduc.size()));
174     aArgs = aArgs.take_front(reduc.size());
175 
176     rewriter.setInsertionPointToStart(after);
177     SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
178     rewriter.setInsertionPointToEnd(after);
179 
180     // Forward loops
181     SmallVector<Value> yields;
182     llvm::append_range(yields, ret);
183     llvm::append_range(yields, it->forward(rewriter, loc));
184     rewriter.create<scf::YieldOp>(loc, yields);
185   }
186   return whileOp.getResults().drop_front(it->getCursor().size());
187 }
188 
189 namespace {
190 
/// Sparse codegen rule for the extract_iteration_space operator: materializes
/// the iteration space for the requested level range and 1:N-replaces the op
/// with the flat list of SSA values backing that space.
class ExtractIterSpaceConverter
    : public OpConversionPattern<ExtractIterSpaceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Construct the iteration space (using the converted tensor value and,
    // for non-root level ranges, the already-converted parent iterator).
    SparseIterationSpace space(loc, rewriter,
                               getSingleValue(adaptor.getTensor()), 0,
                               op.getLvlRange(), adaptor.getParentIter());

    // 1:N replacement: the op's single result becomes the flattened values.
    SmallVector<Value> result = space.toValues();
    rewriter.replaceOpWithMultiple(op, {result});
    return success();
  }
};
211 
/// Sparse codegen rule for the extract_value operator: lowers to a load from
/// the tensor's values buffer at the iterator's current position.
class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // NOTE(review): assumes the last field of the converted iterator is its
    // position into the values memref — confirm against SparseTensorIterator.
    Value pos = adaptor.getIterator().back();
    Value valBuf =
        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
    return success();
  }
};
227 
228 class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
229 public:
230   using OpConversionPattern::OpConversionPattern;
231   LogicalResult
232   matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
233                   ConversionPatternRewriter &rewriter) const override {
234     if (!op.getCrdUsedLvls().empty())
235       return rewriter.notifyMatchFailure(
236           op, "non-empty coordinates list not implemented.");
237 
238     Location loc = op.getLoc();
239 
240     auto iterSpace = SparseIterationSpace::fromValues(
241         op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
242 
243     std::unique_ptr<SparseIterator> it =
244         iterSpace.extractIterator(rewriter, loc);
245 
246     SmallVector<Value> ivs;
247     for (ValueRange inits : adaptor.getInitArgs())
248       llvm::append_range(ivs, inits);
249 
250     // Type conversion on iterate op block.
251     unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
252     TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
253     if (failed(typeConverter->convertSignatureArgs(
254             op.getBody()->getArgumentTypes(), signatureConversion)))
255       return rewriter.notifyMatchFailure(
256           op, "failed to convert iterate region argurment types");
257 
258     Block *block = rewriter.applySignatureConversion(
259         op.getBody(), signatureConversion, getTypeConverter());
260     ValueRange ret = genLoopWithIterator(
261         rewriter, loc, it.get(), ivs,
262         [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
263                 SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264           SmallVector<Value> blockArgs(reduc);
265           // TODO: Also appends coordinates if used.
266           // blockArgs.push_back(it->deref(rewriter, loc));
267           llvm::append_range(blockArgs, it->getCursor());
268 
269           Block *dstBlock = &loopBody.getBlocks().front();
270           rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
271                                      blockArgs);
272           auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
273           // We can not use ValueRange as the operation holding the values will
274           // be destoryed.
275           SmallVector<Value> result(yield.getResults());
276           rewriter.eraseOp(yield);
277           return result;
278         });
279 
280     rewriter.replaceOp(op, ret);
281     return success();
282   }
283 };
284 
285 class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
286   using OpConversionPattern::OpConversionPattern;
287 
288   LogicalResult
289   matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
290                   ConversionPatternRewriter &rewriter) const override {
291     assert(op.getSpaceDim() == 1 && "Not implemented");
292     Location loc = op.getLoc();
293 
294     I64BitSet denseBits(0);
295     for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
296       if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
297         denseBits.set(idx);
298 
299     // If there exists a case that only contains dense spaces. I.e., case
300     // bits is a subset of dense bits, or when there is a full empty case (due
301     // to complements), we need a universal pointer to forward the coiteration
302     // loop.
303     bool needUniv =
304         any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
305           // A case for complement.
306           if (caseBits.count() == 0)
307             return true;
308           // An all-dense case.
309           return caseBits.isSubSetOf(denseBits);
310         });
311     assert(!needUniv && "Not implemented");
312     (void)needUniv;
313 
314     SmallVector<Block *> newBlocks;
315     DenseMap<Block *, Block *> newToOldBlockMap;
316     for (Region &region : op.getCaseRegions()) {
317       // Do a one-shot type conversion on all region blocks, since the same
318       // region might be used multiple time.
319       Block *block = &region.getBlocks().front();
320       TypeConverter::SignatureConversion blockTypeMapping(
321           block->getArgumentTypes().size());
322       if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
323                                                      blockTypeMapping))) {
324         return rewriter.notifyMatchFailure(
325             op, "failed to convert coiterate region argurment types");
326       }
327 
328       newBlocks.push_back(rewriter.applySignatureConversion(
329           block, blockTypeMapping, getTypeConverter()));
330       newToOldBlockMap[newBlocks.back()] = block;
331     }
332 
333     SmallVector<SparseIterationSpace> spaces;
334     SmallVector<std::unique_ptr<SparseIterator>> iters;
335     for (auto [spaceTp, spaceVals] : llvm::zip_equal(
336              op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
337       // TODO: do we really need tid?
338       spaces.push_back(SparseIterationSpace::fromValues(
339           cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
340       // Extract the iterator.
341       iters.push_back(spaces.back().extractIterator(rewriter, loc));
342     }
343 
344     auto getFilteredIters = [&iters](I64BitSet caseBits) {
345       // Retrives a vector of pointers to the iterators used in the case.
346       SmallVector<SparseIterator *> validIters;
347       for (auto idx : caseBits.bits())
348         validIters.push_back(iters[idx].get());
349       return validIters;
350     };
351 
352     // Get a flattened user-provided loop reduction values.
353     SmallVector<Value> userReduc;
354     for (ValueRange r : adaptor.getInitArgs())
355       llvm::append_range(userReduc, r);
356 
357     // TODO: we need to sort the cases such that they appears in lexical order.
358     // Although sparsification always generates cases in that order, it might
359     // not be the case for human-written code.
360 
361     // Generates a loop sequence, one loop per case.
362     for (auto [r, caseBits] :
363          llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
364       assert(caseBits.count() > 0 && "Complement space not implemented");
365 
366       // Retrives a vector of pointers to the iterators used in the case.
367       SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
368 
369       if (validIters.size() > 1) {
370         auto [loop, loopCrd] =
371             genCoIteration(rewriter, loc, validIters, userReduc,
372                            /*uniIdx=*/nullptr, /*userReducFirst=*/true);
373 
374         // 1st. find all the cases that is a strict subset of the current case
375         // condition, for which we generate one branch per case inside the loop.
376         // The subcases are never empty, it must contains at least the current
377         // region itself.
378         // TODO: these cases should be sorted.
379         SmallVector<Region *> subCases =
380             op.getSubCasesOf(r->getParent()->getRegionNumber());
381         SmallVector<Block *> newBlocks, oldBlocks;
382         for (Region *r : subCases) {
383           newBlocks.push_back(&r->front());
384           oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
385         }
386         assert(!subCases.empty());
387 
388         ValueRange res = genCoIterateBranchNest(
389             rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
390 
391         SmallVector<Value> nextIterYields(res);
392         // 2nd. foward the loop.
393         for (SparseIterator *it : validIters) {
394           Value cmp = rewriter.create<arith::CmpIOp>(
395               loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
396           it->forwardIf(rewriter, loc, cmp);
397           llvm::append_range(nextIterYields, it->getCursor());
398         }
399         rewriter.create<scf::YieldOp>(loc, nextIterYields);
400 
401         // Exit the loop, relink the iterator SSA value.
402         rewriter.setInsertionPointAfter(loop);
403         ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
404         for (SparseIterator *it : validIters)
405           iterVals = it->linkNewScope(iterVals);
406         assert(iterVals.empty());
407 
408         ValueRange curResult = loop->getResults().take_front(userReduc.size());
409         userReduc.assign(curResult.begin(), curResult.end());
410       } else {
411         // This is a simple iteration loop.
412         assert(caseBits.count() == 1);
413 
414         Block *block = r;
415         ValueRange curResult = genLoopWithIterator(
416             rewriter, loc, validIters.front(), userReduc,
417             /*bodyBuilder=*/
418             [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
419                     SparseIterator *it,
420                     ValueRange reduc) -> SmallVector<Value> {
421               SmallVector<Value> blockArgs(reduc);
422               blockArgs.push_back(it->deref(rewriter, loc));
423               llvm::append_range(blockArgs, it->getCursor());
424 
425               Block *dstBlock = &dstRegion.getBlocks().front();
426               rewriter.inlineBlockBefore(
427                   block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
428               auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
429               SmallVector<Value> result(yield.getResults());
430               rewriter.eraseOp(yield);
431               return result;
432             });
433 
434         userReduc.assign(curResult.begin(), curResult.end());
435       }
436     }
437 
438     rewriter.replaceOp(op, userReduc);
439     return success();
440   }
441 };
442 
443 } // namespace
444 
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
  // Fallback: every other type is legal as-is.
  addConversion([](Type type) { return type; });
  // 1:N conversions for sparse iterator and iteration-space types. Note that
  // conversions added later take precedence, so these override the fallback.
  addConversion(convertIteratorType);
  addConversion(convertIterSpaceType);

  // When an unconverted iteration-space value is still needed (e.g., by a
  // not-yet-converted user), bridge the gap with an unrealized cast.
  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
                              ValueRange inputs, Location loc) -> Value {
    return builder
        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
        .getResult(0);
  });
}
457 
458 void mlir::populateLowerSparseIterationToSCFPatterns(
459     const TypeConverter &converter, RewritePatternSet &patterns) {
460 
461   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
462   patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
463                SparseIterateOpConverter, SparseCoIterateOpConverter>(
464       converter, patterns.getContext());
465 }
466