xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (revision 06a65ce500a632048db1058de9ca61072004a640)
1 //===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "CodegenUtils.h"
10 #include "IterationGraphSorter.h"
11 
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Utils/Utils.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/AffineMap.h"
22 
23 using namespace mlir;
24 using namespace mlir::sparse_tensor;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // File Local Helper classes.
30 //===----------------------------------------------------------------------===//
31 
32 // A CRTP helper for implementing a rewriter that demaps all its inputs.
33 template <typename SubClass, typename SourceOp>
34 struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
35   using OpRewritePattern<SourceOp>::OpRewritePattern;
36   using OpAdaptor = typename SourceOp::Adaptor;
37 
38   LogicalResult matchAndRewrite(SourceOp op,
39                                 PatternRewriter &rewriter) const override {
40     Location loc = op.getLoc();
41     // Demaps non-trivial inputs.
42     SmallVector<Value> deMappedIns(op->getOperands());
43     for (Value &in : deMappedIns)
44       if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
45         in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
46 
47     // CRTP call.
48     OpAdaptor adaptor(deMappedIns, op);
49     return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
50                                                           rewriter);
51   }
52 };
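//
// Illustrative sketch (not part of the upstream code): for an operand whose
// encoding carries a non-identity dimToLvl, the rewriter above wraps it in a
// sparse_tensor.reinterpret_map so the CRTP subclass only ever sees the
// demapped (level-space) type. Roughly, assuming a hypothetical 2x2 block
// encoding #BSR with dimToLvl (i, j) -> (i floordiv 2, j floordiv 2,
// i mod 2, j mod 2), the inserted op would look like
//   %demapped = sparse_tensor.reinterpret_map %bsr
//       : tensor<4x4xf64, #BSR> to tensor<2x2x2x2xf64, #BSR_demapped>
// where %bsr and #BSR_demapped are placeholder names.
//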
53 
54 // Collects the AffineDimExprs used in an affine expression into a BitVector.
55 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
56   explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
57   void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
58   BitVector dims;
59 };
60 
61 // Checks whether an affine expression is admissible for sparsification.
62 struct AffineExprAdmissibleVisitor
63     : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
64   explicit AffineExprAdmissibleVisitor(bool isOutput)
65       : admissible(true), isOutput(isOutput){};
66 
67   // We only allow AffineDimExpr on output.
68   void visitAddExpr(AffineBinaryOpExpr expr) {
69     if (isOutput)
70       admissible = false;
71   }
72   void visitMulExpr(AffineBinaryOpExpr expr) {
73     if (isOutput)
74       admissible = false;
75   }
76 
77   // We disallow mod, floordiv, and ceildiv on both inputs and outputs.
78   void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
79   void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
80   void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
81   operator bool() { return admissible; }
82 
83 private:
84   bool admissible;
85   bool isOutput;
86 };
87 
88 // The first BitVector marks the levels on which inadmissible expressions
89 // appear. The second BitVector marks the AffineDimExprs that are used by
90 // those inadmissible expressions.
91 using InadmissInfo = std::pair<BitVector, BitVector>;
92 
93 } // namespace
94 
95 //===----------------------------------------------------------------------===//
96 // File Local Helper methods.
97 //===----------------------------------------------------------------------===//
98 
99 // Collects the inadmissible affine expressions imposed on levels.
100 static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
101   auto ret = std::make_pair(BitVector(map.getNumResults()),
102                             BitVector(map.getNumDims()));
103   AffineDimCollector collector(map.getNumDims());
104   for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
105     AffineExprAdmissibleVisitor admissible(isOutput);
106     admissible.walkPostOrder(map.getResult(lvl));
107     if (!admissible) {
108       // Record the inadmissible level.
109       ret.first.set(lvl);
110       // Record the AffineDimExpr that is used in the inadmissible expr.
111       collector.walkPostOrder(map.getResult(lvl));
112     }
113   }
114   ret.second = collector.dims;
115   return ret;
116 }
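//
// Worked example (illustrative only): for an input operand with
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// every level expression is inadmissible, so the first BitVector sets bits
// {0, 1, 2, 3} (all four levels) and the second sets bits {0, 1} (both d0 and
// d1 are used by inadmissible expressions).
//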
117 
118 // Builds the AffineMap that replaces the index variables in idxMap with level
119 // variables such that all the inadmissible affine expressions can be eliminated.
120 // For example, we can rewrite
121 // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
122 // to
123 // idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
124 // by composing inverse(idxMap), that is
125 // inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
126 //                         -> ((l0 * 2 + l2) floordiv 2,
127 //                             (l1 * 3 + l3) floordiv 3,
128 //                             (l0 * 2 + l2) mod 2,
129 //                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
130 //
131 // This function builds the inverse(idxMap) that replaces every dimension used
132 // in `info` with levels, and updates the iterator type array `itTps` for the
133 // new index variables introduced.
134 //
135 // Note that the returned affine map does not retain the order of the input
136 // affine map. Instead, it always uses the first `inAdLvls.count()` dims for the
137 // replaced levels, and the remaining ones for the untouched dimensions.
138 // For example, to handle
139 // idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
140 // which is a typical map for block_2to4. The function returns:
141 // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
142 // in which (l0, l1) together replace `d1`, yet they appear
143 // before `d0` in the resulting affine map.
144 // The index (loop) order can later be canonicalized by a topo sort.
145 static AffineMap
146 genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
147                       SmallVector<utils::IteratorType> &itTps) {
148   MLIRContext *ctx = idxMap.getContext();
149   auto [inAdLvls, usedDims] = info;
150   // Note that idxMap is not the same as the dim2Lvl map; it is computed by
151   // composing dim2Lvl with idx2Dim, i.e., dim2Lvl(idx2Dim). They are only
152   // equal when idx2Dim is an identity map.
153   // TODO: we might fail here; in those cases we should really return
154   // failure instead of an assertion error.
155   auto lvl2Idx = inferLvlToDim(idxMap, ctx);
156 
157   assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
158   if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
159     // This could happen when some dimensions are projected.
160     // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
161     //   ==> lvl2Idx = (j, k) -> (j, k)
162     // In this case, we append the unused dimensions at the end.
163     //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
164     SmallVector<AffineExpr> results;
165     AffineDimCollector usedInLvl(idxMap.getNumDims());
166     for (auto e : idxMap.getResults())
167       usedInLvl.walkPostOrder(e);
168 
169     unsigned curUsedDimID = 0;
170     unsigned curUnusedDimID = lvl2Idx.getNumDims();
171 
172     BitVector unused = usedInLvl.dims.flip();
173     for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
174       if (unused.test(i))
175         results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
176       else
177         results.push_back(lvl2Idx.getResult(curUsedDimID++));
178     }
179     lvl2Idx =
180         AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
181   }
182   assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
183 
184   // We do not need to replace the DimExprs that are not used in inadmissible
185   // level expressions. We use the first inAdLvls.count() dims to represent the
186   // replaced levels; the remaining ones are reserved for unchanged dimensions.
187   // Note that the results from the inverse map computed previously do not
188   // follow this convention, and we need to fix the mismatch below.
189   unsigned curRepID = 0;
190   unsigned curOriID = inAdLvls.count();
191   SmallVector<AffineExpr> results;
192   SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
193   SmallVector<utils::IteratorType> transItTps;
194 
195   for (unsigned l : inAdLvls.set_bits()) {
196     // By our convention, the inadmissible level `l` always appears in the
197     // leading part (accumulated by curRepID) of the affine map's parameter
198     // list. Record the mapping so that we can replace all the uses of `l` with
199     // the correct position after the translation.
200     dimRep[l] = getAffineDimExpr(curRepID++, ctx);
201     // A new index variable is introduced for the inadmissible level; it
202     // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
203     // iterator type of l0 equals the iterator type of d0.
204     AffineExpr lvlExp = idxMap.getResult(l);
205     AffineDimCollector collector(idxMap.getNumDims());
206     collector.walkPostOrder(lvlExp);
207     // We assume a level can only be derived from one dimension.
208     assert(collector.dims.count() == 1);
209     transItTps.push_back(itTps[collector.dims.find_first()]);
210   }
211 
212   for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
213     if (usedDims.test(d)) {
214       // The dimension is used in some of the inadmissible levels, and it needs
215       // to be inverted. Get the inverse expression from the inverse map, and
216       // fix the mismatch recorded by the loop above.
217       results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
218     } else {
219       // The dimension is not used in any of the inadmissible levels, and it
220       // does not need to be inverted. Fix the mismatch by mapping it to the
221       // trailing part of the affine map (accumulated by curOriID).
222       results.push_back(getAffineDimExpr(curOriID++, ctx));
223       transItTps.push_back(itTps[d]);
224     }
225   }
226   unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
227   // Update iterator type.
228   itTps.assign(transItTps.begin(), transItTps.end());
229   return AffineMap::get(numDim, 0, results, ctx);
230 }
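//
// Worked trace (illustrative only) of the 2:4 example above: for
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// we get inAdLvls = {1, 2} and usedDims = {d1}, so numDim = 2 - 1 + 2 = 3,
// the returned map is (l0, l1, d0) -> (d0, l0 * 4 + l1), and `itTps` becomes
// [itTp(d1), itTp(d1), itTp(d0)]: the two new level variables inherit the
// iterator type of d1, while the untouched d0 keeps its own.
//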
231 
232 // Translates the index maps in the linalg::GenericOp from idx->dim maps to
233 // idx->lvl maps. Returns std::nullopt if the index maps cannot be translated
234 // to an admissible form; otherwise returns the translated index map array and
235 // the iterator type array.
236 static std::optional<std::pair<ArrayAttr, ArrayAttr>>
237 translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
238   // Each idxMap is an idx2dim map before reinterpretation.
239   MLIRContext *ctx = op.getContext();
240   SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
241   SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
242   for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
243     Value tensor = op->getOpOperand(i).get();
244     auto stt = tryGetSparseTensorType(tensor);
245     if (stt && !stt->isIdentity()) {
246       AffineMap dim2Lvl = stt->getDimToLvl();
247       // By composing dim2lvl(idx2dim), we get an idx2lvl map.
248       idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
249     }
250   }
251 
252   // A naive way to handle common constant expressions that arise during dim2lvl
253   // translation.
254   auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
255                                   unsigned pos, int64_t lvlSz) {
256     if (!ShapedType::isDynamic(lvlSz)) {
257       auto c0 = getAffineConstantExpr(0, ctx);
258       auto lvlExp = getAffineDimExpr(pos, ctx);
259       auto szExp = getAffineConstantExpr(lvlSz, ctx);
260 
261       // lvl floordiv lvlSz = 0
262       auto divExp =
263           getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
264       cstMapping.try_emplace(divExp, c0);
265 
266       // lvl mod lvlSz = lvl
267       auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
268       cstMapping.try_emplace(modExp, lvlExp);
269     }
270   };
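  // Illustrative example (not part of the upstream code): with pos = 2 and
  // lvlSz = 4 (say, the block-local level of a hypothetical 2:4 encoding),
  // this registers the folds `d2 floordiv 4 -> 0` and `d2 mod 4 -> d2` for the
  // composed index maps below, since the new level variable ranges in [0, 4).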
271 
272   unsigned boundedNum = 0;
273   // A fixed-point algorithm.
274   bool changed = true;
275   while (changed) {
276     changed = false;
277     for (OpOperand &operand : op->getOpOperands()) {
278       auto stt = tryGetSparseTensorType(operand.get());
279       // Skip on dense operands.
280       if (!stt || !stt->getEncoding())
281         continue;
282 
283       unsigned tid = operand.getOperandNumber();
284       bool isOutput = &operand == op.getDpsInitOperand(0);
285       AffineMap idxMap = idxMapArray[tid];
286       InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
287       auto [inAdLvls, dimExprs] = inAdInfo;
288       for (unsigned d : dimExprs.set_bits()) {
289         // The first `boundedNum` dims in the AffineMap were introduced to
290         // resolve previously inadmissible expressions. We cannot replace
291         // them, as doing so might bring back the inadmissible expressions.
292         if (d < boundedNum)
293           return std::nullopt;
294       }
295 
296       if (inAdLvls.count() != 0) {
297         // Naive constant propagation, which should be sufficient to handle
298         // block sparsity in our cases.
299         SmallVector<int64_t> lvlShape = stt->getLvlShape();
300         DenseMap<AffineExpr, AffineExpr> cstMapping;
301         unsigned position = 0;
302         for (unsigned lvl : inAdLvls.set_bits()) {
303           int64_t lvlSz = lvlShape[lvl];
304           populateCstMapping(cstMapping, position, lvlSz);
305           position++;
306         }
307 
308         AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
309         // Compose lvl2Idx with all the index maps to eliminate the
310         // inadmissible expressions.
311         for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
312           AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
313           idxMapArray[tid] = transMap.replace(
314               cstMapping, /*numResultDims=*/transMap.getNumDims(),
315               /*numResultSyms=*/0);
316         }
317         changed = true;
318         boundedNum += inAdLvls.count();
319       }
320     }
321   };
322 
323   SmallVector<Attribute> iterAttr =
324       llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
325         return linalg::IteratorTypeAttr::get(ctx, itTp);
326       });
327 
328   return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
329                         rewriter.getArrayAttr(iterAttr));
330 }
331 
332 // Generates a "de"mapping reinterpretation of the map.
333 static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
334                       Value val) {
335   return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
336                                           val);
337 }
338 
339 // Generates a "re"mapping reinterpretation of the map.
340 static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
341                       Value val) {
342   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
343 }
344 
345 static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
346                                           ValueRange outs) {
347   SmallVector<Value> ret(outs);
348   assert(outs.size() == types.size());
349   for (auto [r, t] : llvm::zip(ret, types))
350     if (r.getType() != t)
351       r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
352   return ret;
353 }
354 
355 namespace {
356 
357 //===----------------------------------------------------------------------===//
358 // Rewriting rules for linalg generic ops.
359 //===----------------------------------------------------------------------===//
360 
361 /// Sparse rewriting rule for the generic `linalg` operation.
362 struct GenericOpReinterpretMap
363     : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
364 public:
365   using DemapInsRewriter::DemapInsRewriter;
366   LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
367                           PatternRewriter &rewriter) const {
368     // Only rewrite single output operations with pure (sparse) tensor
369     // semantics.
370     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
371         !hasAnySparseOperandOrResult(linalgOp) ||
372         !hasAnyNonIdentityOperandsOrResults(linalgOp))
373       return failure();
374 
375     // Try translating the index map.
376     auto transMap = translateMap(linalgOp, rewriter);
377     if (!transMap)
378       return rewriter.notifyMatchFailure(
379           linalgOp, "the sparse kernel cannot be sparsified.");
380 
381     // On success, update the linalg operands and maps in place.
382     Value res = linalgOp.getResult(0);
383     auto stt = tryGetSparseTensorType(res);
384     auto [idxMap, itTp] = *transMap;
385 
386     rewriter.startRootUpdate(linalgOp);
387     linalgOp.setIndexingMapsAttr(idxMap);
388     linalgOp.setIteratorTypesAttr(itTp);
389     // Use demapped arguments.
390     linalgOp.getInputsMutable().assign(adaptor.getInputs());
391     linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
392     res.setType(adaptor.getOutputs()[0].getType());
393     rewriter.finalizeRootUpdate(linalgOp);
394 
395     rewriter.setInsertionPointAfter(linalgOp);
396     if (stt && stt->hasEncoding()) {
397       Value t = genRemap(rewriter, stt->getEncoding(), res);
398       rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
399     }
400     return success();
401   }
402 };
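// Rough before/after sketch (illustrative only, placeholder names): a
// linalg.generic with a non-identity sparse operand %bsr is rewritten so its
// operands are the demapped sparse_tensor.reinterpret_map results, its
// indexing maps and iterator types are the translated level-space ones from
// translateMap, and the original sparse result is re-mapped back to its
// original encoding right after the op for the remaining users.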
403 
404 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
405   using OpRewritePattern::OpRewritePattern;
406   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
407                                 PatternRewriter &rewriter) const override {
408     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
409         hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
410         !hasAnySparseOperandOrResult(linalgOp)) {
411       return failure();
412     }
413 
414     const StringRef sorted = "sorted";
415     if (linalgOp->hasAttr(sorted))
416       return failure();
417 
418     auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
419     bool isAdmissible = false;
420     AffineMap order;
421     // A const list of all masks that we use for iteration graph
422     // computation. Must be ordered from more strict to less strict.
423     // Ideally (though not guaranteed), the earlier a constraint mask
424     // can be satisfied, the faster the generated kernel will be.
425     const auto allMasks = {
426         SortMask::kIncludeAll,        SortMask::kIncludeDense,
427         SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
428         SortMask::kIncludeUndef,      SortMask::kSparseOnly};
429     for (const SortMask mask : allMasks) {
430       order = scheduler.sort(mask);
431       if (order) {
432         if (isAdmissibleOrder(linalgOp, order)) {
433           isAdmissible = true;
434           break;
435         }
436         // else try a set of less strict constraints.
437       }
438     }
439 
440     if (!order) {
441       // Cycles detected.
442       if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
443         return rewriter.notifyMatchFailure(
444             linalgOp, "the sparse kernel cannot be scheduled: loop detected.");
445       }
446       return success();
447     }
448 
449     if (!isAdmissible) {
450       return rewriter.notifyMatchFailure(
451           linalgOp, "the sparse kernel cannot be scheduled.");
452     }
453 
454     // Marks the GenericOp to avoid recursive matching.
455     linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
456 
457     // Already sorted.
458     if (order.isIdentity())
459       return failure();
460 
461     assert(order.isPermutation());
462     // `order` is the original loop -> sorted loop map.
463     ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
464     SmallVector<Attribute> curItTypes;
465     curItTypes.reserve(preItTypes.size());
466     for (AffineExpr expr : order.getResults()) {
467       unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
468       curItTypes.push_back(preItTypes[loopID]);
469     }
470 
471     // Invert `order` to get the sorted loop -> original loop map.
472     order = inversePermutation(order);
473     SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
474     for (AffineMap &idxMap : idxMaps)
475       idxMap = idxMap.compose(order); // sorted loop -> lvl map
476 
477     rewriter.startRootUpdate(linalgOp);
478     linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
479     linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
480     rewriter.finalizeRootUpdate(linalgOp);
481 
482     return success();
483   }
484 
485 private:
486   /// Whether the loop order is admissible by sparsification.
487   static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
488     if (!hasAnySparseResult(linalgOp))
489       return true;
490 
491     OpOperand *lhs = linalgOp.getDpsInitOperand(0);
492     unsigned nest = 0;
493     const auto iteratorTypes = linalgOp.getIteratorTypesArray();
494     for (const AffineExpr l : order.getResults()) {
495       unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
496       auto itTp =
497           linalgOp.getIteratorTypes()[loopId].cast<linalg::IteratorTypeAttr>();
498       if (linalg::isReductionIterator(itTp.getValue()))
499         break; // terminate at first reduction
500       nest++;
501     }
502     // Determine admissible dynamic insertion situations:
503     // (1) fully injective, since there are no reductions,
504     // (2) admissible 1-d expansion in innermost dimension.
505     return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
506   };
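  // Illustrative example: with iterator types [parallel, parallel, reduction]
  // and a rank-2 sparse output, any order that keeps both parallel loops ahead
  // of the reduction yields nest = 2 >= rank - 1 = 1 and is admissible, while
  // an order that schedules the reduction loop first yields nest = 0 and is
  // rejected.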
507 
508   // Last resort cycle resolution.
509   static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
510                                     linalg::LinalgOp linalgOp,
511                                     PatternRewriter &rewriter) {
512     // Compute topological sort while leaving out every sparse input tensor in
513     // succession until an acyclic iteration graph results.
514     for (OpOperand *t : linalgOp.getDpsInputOperands()) {
515       Value tval = t->get();
516       auto srcEnc = getSparseTensorEncoding(tval.getType());
517       // The constraints introduced by compound index expressions are
518       // complicated. Skip them.
519       AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
520       bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
521         return !llvm::isa<AffineDimExpr>(exp);
522       });
523       if (!srcEnc || hasCompExpr)
524         continue;
525 
526       // Try scheduling loop without constraints from `tval`.
527       AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
528       if (!order) // still cyclic
529         continue;
530 
531       // Found an input tensor that resolves the cycle by inserting a
532       // conversion into a sparse tensor that adheres to the iteration
533       // graph order.
534       auto stt = getSparseTensorType(tval);
535       assert(stt.isIdentity());
536       order = inversePermutation(order);
537       // sorted loop -> lvl map.
538       idxMap = idxMap.compose(order);
539 
540       // Found a permutation such that the results in `idxMap` are sorted.
541       // For example, given
542       //  (d0, d1, d2, d3) -> (d2, d1, d0)
543       // the loops are scheduled in the order d0->d1->d2->d3; to resolve the
544       // cycle, we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such
545       // that the transposed tensor's levels are visited in the same order as
546       // the loop scheduling order.
547       SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
548       for (AffineExpr expr : idxMap.getResults()) {
549         unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
550         lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
551       }
552       std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
553         return lhs.first < rhs.first;
554       });
555       SmallVector<unsigned> perm =
556           llvm::to_vector(llvm::make_second_range(lvlSeq));
557       auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
558       // The results of idxMap must not already be sorted.
559       assert(!dimToLvl.isIdentity());
560 
561       // Insert the transpose.
562       rewriter.setInsertionPoint(linalgOp);
563       RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
564       Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
565       rewriter.updateRootInPlace(linalgOp, [&]() {
566         linalgOp->setOperand(t->getOperandNumber(), dst);
567       });
568       return success();
569     }
570     // Cannot be resolved with a single conversion.
571     // TODO: convert more than one?
572     return failure();
573   }
574 };
575 
576 //===----------------------------------------------------------------------===//
577 // Reinterpret Map Rewriters for operations other than linalg.generics
578 //===----------------------------------------------------------------------===//
579 
580 template <typename AllocOp>
581 struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
582   using OpRewritePattern<AllocOp>::OpRewritePattern;
583   LogicalResult matchAndRewrite(AllocOp op,
584                                 PatternRewriter &rewriter) const override {
585     if (!hasAnyNonIdentityOperandsOrResults(op))
586       return failure();
587 
588     Location loc = op.getLoc();
589     auto stt = getSparseTensorType(op.getResult());
590 
591     SmallVector<Value> maxDimCrds;
592     maxDimCrds.reserve(stt.getDimRank());
593     ValueRange dynSz = op.getDynamicSizes();
594     for (int64_t dimSz : stt.getDimShape()) {
595       if (ShapedType::isDynamic(dimSz)) {
596         Value maxCrd = rewriter.create<arith::SubIOp>(
597             loc, dynSz.front(), constantIndex(rewriter, loc, 1));
598         maxDimCrds.push_back(maxCrd);
599         dynSz = dynSz.drop_front();
600       } else {
601         maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
602       }
603     }
604 
605     ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
606                                               CrdTransDirectionKind::dim2lvl);
607     auto lvlShape = stt.getLvlShape();
608     SmallVector<Value> dynLvlSzs;
609     for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
610       if (ShapedType::isDynamic(lvlShape[i])) {
611         Value sz = rewriter.create<arith::AddIOp>(
612             loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
613         dynLvlSzs.push_back(sz);
614       }
615     }
616 
617     assert(dynSz.empty()); // should have consumed all.
618     rewriter.startRootUpdate(op);
619     op->setOperands(dynLvlSzs);
620     op.getResult().setType(stt.getDemappedType());
621     rewriter.finalizeRootUpdate(op);
622     rewriter.setInsertionPointAfter(op);
623 
624     Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
625     rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
626     return success();
627   }
628 };
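// Illustrative sketch (placeholder names, not part of the upstream code): for
//   %t = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #BSR>
// with 2x2 blocks, the pattern above computes the maximum dim coordinates
// (%d0 - 1, %d1 - 1), translates them dim2lvl, adds 1 to obtain the dynamic
// level sizes, re-types the allocation to the demapped
// tensor<?x?x2x2xf64, #BSR_demapped>, and re-maps the result back to #BSR for
// the remaining users.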
629 
630 struct TensorInsertDemapper
631     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
632   using DemapInsRewriter::DemapInsRewriter;
633   LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
634                           PatternRewriter &rewriter) const {
635     if (!hasAnySparseResult(op))
636       return failure();
637 
638     Location loc = op.getLoc();
639     auto stt = getSparseTensorType(op.getResult());
640     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
641                                           CrdTransDirectionKind::dim2lvl);
642     auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
643         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
644 
645     Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
646     rewriter.replaceOp(op, out);
647     return success();
648   }
649 };
650 
651 struct ForeachOpDemapper
652     : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
653   using DemapInsRewriter::DemapInsRewriter;
654   LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
655                           PatternRewriter &rewriter) const {
656     // Only handle operations whose sparse inputs/outputs have non-identity
657     // dim2lvl maps.
658     if (!hasAnyNonIdentityOperandsOrResults(op))
659       return failure();
660 
661     // TODO: demap constant as well.
662     if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
663       if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
664         return failure();
665 
666     Location loc = op.getLoc();
667     // Cache the type information since we update the foreach op in-place.
668     auto srcStt = getSparseTensorType(op.getTensor());
669     SmallVector<Type> prevRetTps(op.getResultTypes());
670 
671     rewriter.startRootUpdate(op);
672     op.getTensorMutable().assign(adaptor.getTensor());
673     op.getInitArgsMutable().assign(adaptor.getInitArgs());
674     // Update results' types.
675     for (auto r : op.getResults())
676       if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
677         r.setType(stt->getDemappedType());
678 
679     Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
680     // Update the foreach body.
681     SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
682     blockArgTps.push_back(srcStt.getElementType());
683     blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
684                        adaptor.getInitArgs().getTypes().end());
685     Block *body = op.getBody();
686     // Block Args: [dimCrd, val, initArgs]
687     unsigned preArgNum = body->getNumArguments();
688     for (Type t : blockArgTps)
689       body->addArgument(t, loc);
690 
691     // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
692     rewriter.setInsertionPointToStart(body);
693     ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
694 
695     ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
696                                               CrdTransDirectionKind::lvl2dim);
697     rewriter.replaceAllUsesWith(
698         body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
699     body->eraseArguments(0, srcStt.getDimRank());
700     // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
701     unsigned numInitArgs = op.getInitArgs().size();
702     rewriter.replaceAllUsesWith(body->getArgument(0),
703                                 body->getArgument(lvlRank + numInitArgs + 1));
704     body->eraseArgument(0);
705     // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
706     ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
707     ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
708     // Remap back before replacement.
709     SmallVector<Value> reMappedArgs =
710         remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
711     rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
712     body->eraseArguments(0, numInitArgs);
713     // Block Args: [lvlCrds, DemappedArgs] and we are done.
714 
715     // Update yield operations.
716     if (numInitArgs != 0) {
717       rewriter.setInsertionPointToEnd(body);
718       auto yield = llvm::cast<YieldOp>(body->getTerminator());
719       if (auto stt = tryGetSparseTensorType(yield.getResult());
720           stt && !stt->isIdentity()) {
721         Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
722         rewriter.create<YieldOp>(loc, y);
723         rewriter.eraseOp(yield);
724       }
725     }
726     rewriter.finalizeRootUpdate(op);
727 
728     rewriter.setInsertionPointAfter(op);
729     SmallVector<Value> outs =
730         remapValueRange(rewriter, prevRetTps, op.getResults());
731 
732     // Replace all the uses of the foreach results, except the use in the
733     // reinterpret_map that remaps the output.
734     for (auto [from, to] : llvm::zip(op.getResults(), outs))
735       rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
736 
737     return success();
738   }
739 };
740 
741 } // namespace
742 
743 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
744                                         ReinterpretMapScope scope) {
745   if (scope == ReinterpretMapScope::kAll ||
746       scope == ReinterpretMapScope::kGenericOnly) {
747     patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
748         patterns.getContext());
749   }
750   if (scope == ReinterpretMapScope::kAll ||
751       scope == ReinterpretMapScope::kExceptGeneric) {
752     patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
753                  TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
754                  ForeachOpDemapper>(patterns.getContext());
755   }
756 }
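
// Usage sketch (illustrative only, not part of this file): a pass would
// typically collect these patterns and apply them greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
// where `ctx` is the MLIRContext and `op` is the operation being rewritten.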
757