xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (revision 94e27c265a9aeb3659175ecee81a68d1763e0180)
1 //===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 #include "Utils/IterationGraphSorter.h"
11 
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Utils/Utils.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/AffineMap.h"
22 
23 using namespace mlir;
24 using namespace mlir::sparse_tensor;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // File Local Helper classes.
30 //===----------------------------------------------------------------------===//
31 
32 // CRTP to help implementing a rewriter that demaps all its inputs.
33 template <typename SubClass, typename SourceOp>
34 struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
35   using OpRewritePattern<SourceOp>::OpRewritePattern;
36   using OpAdaptor = typename SourceOp::Adaptor;
37 
38   LogicalResult matchAndRewrite(SourceOp op,
39                                 PatternRewriter &rewriter) const override {
40     Location loc = op.getLoc();
41 
42     // Demaps non-trivial inputs.
43     bool changed = false;
44     SmallVector<Value> deMappedIns(op->getOperands());
45     for (Value &in : deMappedIns) {
46       if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
47         in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
48         changed = true;
49       }
50     }
51 
52     // CRTP call.
53     OpAdaptor adaptor(deMappedIns, op);
54     LogicalResult status =
55         static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
56     return changed ? success() : status;
57   }
58 };
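// For illustration, a hypothetical subclass (the op name below is made up)
// only needs to provide `rewriteOp`; the demapped operands arrive through the
// adaptor:
//
//   struct MyOpDemapper : public DemapInsRewriter<MyOpDemapper, MyOp> {
//     using DemapInsRewriter::DemapInsRewriter;
//     LogicalResult rewriteOp(MyOp op, OpAdaptor adaptor,
//                             PatternRewriter &rewriter) const {
//       // adaptor.getOperands() now yields tensors in level space, i.e.,
//       // tensors whose dim2lvl maps are identities.
//       ...
//     }
//   };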
59 
60 // Collects the dimensions (AffineDimExprs) used in an affine expression.
61 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
62   explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
63   void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
64   BitVector dims;
65 };
66 
67 // Determines whether an affine expression is admissible for sparsification.
68 struct AffineExprAdmissibleVisitor
69     : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
70   explicit AffineExprAdmissibleVisitor(bool isOutput)
71       : admissible(true), isOutput(isOutput){};
72 
73   // We only allow AffineDimExpr on output.
74   void visitAddExpr(AffineBinaryOpExpr expr) {
75     if (isOutput)
76       admissible = false;
77   }
78   void visitMulExpr(AffineBinaryOpExpr expr) {
79     if (isOutput)
80       admissible = false;
81   }
82 
83   // We disallow mod, floordiv, and ceildiv on inputs.
84   void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
85   void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
86   void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
87   operator bool() { return admissible; }
88 
89 private:
90   bool admissible;
91   bool isOutput;
92 };
93 
94 // The first BitVector stores levels where inadmissible exprs are used.
95 // The second BitVector stores the AffineDimExprs that are used by the
96 // inadmissible expressions.
97 using InadmissInfo = std::pair<BitVector, BitVector>;
98 
99 } // namespace
100 
101 //===----------------------------------------------------------------------===//
102 // File Local Helper methods.
103 //===----------------------------------------------------------------------===//
104 
105 // Collects the inadmissible affine expressions imposed on levels.
106 static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
107   auto ret = std::make_pair(BitVector(map.getNumResults()),
108                             BitVector(map.getNumDims()));
109   AffineDimCollector collector(map.getNumDims());
110   for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
111     AffineExprAdmissibleVisitor admissible(isOutput);
112     admissible.walkPostOrder(map.getResult(lvl));
113     if (!admissible) {
114       // Record the inadmissible level.
115       ret.first.set(lvl);
116       // Record the AffineDimExpr that is used in the inadmissible expr.
117       collector.walkPostOrder(map.getResult(lvl));
118     }
119   }
120   ret.second = collector.dims;
121   return ret;
122 }
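// For illustration, on a hypothetical map
//   (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2)
// with isOutput == false, levels 0 and 2 use floordiv/mod and are therefore
// inadmissible, so the first BitVector has bits {0, 2} set, while the second
// BitVector has only bit {0} set, since d0 is the sole dimension appearing in
// the inadmissible expressions.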
123 
124 // Builds the AffineMap to replace the indices in idxMap with levels such that
125 // all the inadmissible affine expressions can be eliminated.
126 // For example, we can rewrite
127 // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
128 // to
129 // idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
130 // by composing inverse(idxMap), that is
131 // inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
132 //                         -> ((l0 * 2 + l2) floordiv 2,
133 //                             (l1 * 3 + l3) floordiv 3,
134 //                             (l0 * 2 + l2) mod 2,
135 //                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
136 //
137 // This function builds the inverse(idxMap) that replaces every dimension used
138 // in `info` with a level, and updates the iterator type array `itTps` for the
139 // newly introduced index variables.
140 //
141 // Note that the returned affine map does not retain the order of the input
142 // affine map. Instead, it always uses the first `inAdLvls.count()` dims for
143 // the replaced levels, and the remaining ones for the untouched dimensions.
144 // For example, to handle
145 // idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4),
146 // which is a typical map for block_2to4, the function returns:
147 // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
148 // in which (l0, l1) together replace `d1`, yet they appear
149 // before `d0` in the resulting affine map.
150 // The index (loop) order can later be canonicalized by a topo sort.
151 static AffineMap
152 genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
153                       SmallVector<utils::IteratorType> &itTps) {
154   MLIRContext *ctx = idxMap.getContext();
155   auto [inAdLvls, usedDims] = info;
156   // Note that idxMap is not equal to the dim2Lvl map; it is computed by
157   // composing dim2Lvl(idx2Dim). The two are only equal when idx2Dim is an
158   // identity map.
159   // TODO: we might fail here; in those cases we should really return
160   // failure instead of an assertion error.
161   auto lvl2Idx = inferLvlToDim(idxMap, ctx);
162 
163   assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
164   if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
165     // This could happen when some dimensions are projected.
166     // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
167     //   ==> lvl2Idx = (j, k) -> (j, k)
168     // In this case, we append the unused dimensions at the end.
169     //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
170     SmallVector<AffineExpr> results;
171     AffineDimCollector usedInLvl(idxMap.getNumDims());
172     for (auto e : idxMap.getResults())
173       usedInLvl.walkPostOrder(e);
174 
175     unsigned curUsedDimID = 0;
176     unsigned curUnusedDimID = lvl2Idx.getNumDims();
177 
178     BitVector unused = usedInLvl.dims.flip();
179     for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
180       if (unused.test(i))
181         results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
182       else
183         results.push_back(lvl2Idx.getResult(curUsedDimID++));
184     }
185     lvl2Idx =
186         AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
187   }
188   assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
189 
190   // We do not need to replace the DimExprs that are not used in inadmissible
191   // level expressions. We use the first inAdLvls.count() dims to represent the
192   // replaced levels; the remaining ones are reserved for the unchanged dims.
193   // Note that the results of the inverse map computed previously do not follow
194   // this convention, and we need to fix the mismatch below.
195   unsigned curRepID = 0;
196   unsigned curOriID = inAdLvls.count();
197   SmallVector<AffineExpr> results;
198   SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
199   SmallVector<utils::IteratorType> transItTps;
200 
201   for (unsigned l : inAdLvls.set_bits()) {
202     // By our convention, the inadmissible level `l` always appears in the
203     // leading part (accumulated by curRepID) of the affine map's parameter
204     // list. Record the mapping so that we can replace all the uses of `l` to
205     // the correct position after the translation.
206     dimRep[l] = getAffineDimExpr(curRepID++, ctx);
207     // A new index variable is introduced for the inadmissible level; it
208     // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
209     // iterator type of l0 equals the iterator type of d0.
210     AffineExpr lvlExp = idxMap.getResult(l);
211     AffineDimCollector collector(idxMap.getNumDims());
212     collector.walkPostOrder(lvlExp);
213     // We assume a level can only be derived from one dimension.
214     assert(collector.dims.count() == 1);
215     transItTps.push_back(itTps[collector.dims.find_first()]);
216   }
217 
218   for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
219     if (usedDims.test(d)) {
220       // The dimension is used in some of the inadmissible levels, and it
221       // needs to be inverted. Get the inversion from the inverse map, and fix
222       // the mismatch captured by the above loop.
223       results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
224     } else {
225       // The dimension is not used in any of the inadmissible levels, and it
226       // does not need to be inverted. Fix the mismatch by mapping it to the
227       // trailing part of the affine map (accumulated by curOriID).
228       results.push_back(getAffineDimExpr(curOriID++, ctx));
229       transItTps.push_back(itTps[d]);
230     }
231   }
232   unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
233   // Update iterator type.
234   itTps.assign(transItTps.begin(), transItTps.end());
235   return AffineMap::get(numDim, 0, results, ctx);
236 }
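// For illustration, for the floordiv/mod example in the comment above and
// assuming itTps = [parallel, parallel], this function returns
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3)
// and updates itTps to four parallel iterators, since every new index
// variable inherits the iterator type of the dimension it is derived from.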
237 
238 // Translates the index maps in the linalg::GenericOp from idx->dim maps to
239 // idx->lvl maps. Returns std::nullopt if the index maps cannot be translated
240 // to an admissible form.
241 // On success, returns the translated index map array and the iterator types.
242 static std::optional<std::pair<ArrayAttr, ArrayAttr>>
243 translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
244   // idxMap is an idx2dim map before reinterpretation.
245   MLIRContext *ctx = op.getContext();
246   SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
247   SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
248   for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
249     Value tensor = op->getOpOperand(i).get();
250     auto stt = tryGetSparseTensorType(tensor);
251     if (stt && !stt->isIdentity()) {
252       AffineMap dim2Lvl = stt->getDimToLvl();
253       // By composing dim2lvl(idx2dim), we get an idx2lvl map.
254       idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
255     }
256   }
257 
258   // A naive way to handle common constant expressions that arise during dim2lvl
259   // translation.
260   auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
261                                   unsigned pos, int64_t lvlSz) {
262     if (!ShapedType::isDynamic(lvlSz)) {
263       auto c0 = getAffineConstantExpr(0, ctx);
264       auto lvlExp = getAffineDimExpr(pos, ctx);
265       auto szExp = getAffineConstantExpr(lvlSz, ctx);
266 
267       // lvl floordiv lvlSz = 0
268       auto divExp =
269           getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
270       cstMapping.try_emplace(divExp, c0);
271 
272       // lvl mod lvlSz = lvl
273       auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
274       cstMapping.try_emplace(modExp, lvlExp);
275     }
276   };
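  // For illustration, for an inadmissible level at position 0 with a static
  // level size of 4, the lambda above records the two rewrites
  //   d0 floordiv 4  ->  0
  //   d0 mod 4       ->  d0
  // which are applied later through AffineMap::replace.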
277 
278   unsigned boundedNum = 0;
279   // A fixed-point algorithm.
280   bool changed = true;
281   while (changed) {
282     changed = false;
283     for (OpOperand &operand : op->getOpOperands()) {
284       auto stt = tryGetSparseTensorType(operand.get());
285       // Skip on dense operands.
286       if (!stt || !stt->getEncoding())
287         continue;
288 
289       unsigned tid = operand.getOperandNumber();
290       bool isOutput = &operand == op.getDpsInitOperand(0);
291       AffineMap idxMap = idxMapArray[tid];
292       InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
293       auto [inAdLvls, dimExprs] = inAdInfo;
294       for (unsigned d : dimExprs.set_bits()) {
295         // The first `boundedNum` dims used in the AffineMap were introduced
296         // to resolve previous inadmissible expressions. We cannot replace
297         // them, as doing so might bring back the inadmissible expressions.
298         if (d < boundedNum)
299           return std::nullopt;
300       }
301 
302       if (inAdLvls.count() != 0) {
303         // Naive constant propagation; should be sufficient to handle block
304         // sparsity in our cases.
305         SmallVector<int64_t> lvlShape = stt->getLvlShape();
306         DenseMap<AffineExpr, AffineExpr> cstMapping;
307         unsigned position = 0;
308         for (unsigned lvl : inAdLvls.set_bits()) {
309           int64_t lvlSz = lvlShape[lvl];
310           populateCstMapping(cstMapping, position, lvlSz);
311           position++;
312         }
313 
314         AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
315         // Compose the lvl2Idx map with all the index maps to eliminate the
316         // inadmissible expressions.
317         for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
318           AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
319           idxMapArray[tid] = transMap.replace(
320               cstMapping, /*numResultDims=*/transMap.getNumDims(),
321               /*numResultSyms=*/0);
322         }
323         changed = true;
324         boundedNum += inAdLvls.count();
325       }
326     }
327   }
328 
329   SmallVector<Attribute> iterAttr =
330       llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
331         return linalg::IteratorTypeAttr::get(ctx, itTp);
332       });
333 
334   return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
335                         rewriter.getArrayAttr(iterAttr));
336 }
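// For illustration (a sketch, assuming a 2x2 block-sparse operand with an
// identity idx2dim map): with dim2lvl
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
// the composed idx2lvl map has the same form, and all four levels are
// inadmissible. genReplaceDimToLvlMap then yields
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3),
// and after composition plus the constant mapping for the two static block
// levels (size 2), the operand's index map simplifies to the identity
//   (d0, d1, d2, d3) -> (d0, d1, d2, d3),
// i.e., the kernel is now expressed directly over the four levels.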
337 
338 // Generates a "de"mapping reinterpretation of the map.
339 static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
340                       Value val) {
341   return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
342                                           val);
343 }
344 
345 // Generates a "re"mapping reinterpretation of the map.
346 static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
347                       Value val) {
348   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
349 }
350 
351 static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
352                                           ValueRange outs) {
353   SmallVector<Value> ret(outs);
354   assert(outs.size() == types.size());
355   for (auto [r, t] : llvm::zip(ret, types))
356     if (r.getType() != t)
357       r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
358   return ret;
359 }
360 
361 namespace {
362 
363 //===----------------------------------------------------------------------===//
364 // Rewriting rules for linalg generic ops.
365 //===----------------------------------------------------------------------===//
366 
367 /// Sparse rewriting rule for the generic `linalg` operation.
368 struct GenericOpReinterpretMap
369     : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
370 public:
371   using DemapInsRewriter::DemapInsRewriter;
372   LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
373                           PatternRewriter &rewriter) const {
374     // Only rewrite single output operations with pure (sparse) tensor
375     // semantics.
376     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
377         !hasAnySparseOperandOrResult(linalgOp) ||
378         !hasAnyNonIdentityOperandsOrResults(linalgOp))
379       return failure();
380 
381     // Try translating the index map.
382     auto transMap = translateMap(linalgOp, rewriter);
383     if (!transMap)
384       return rewriter.notifyMatchFailure(
385           linalgOp, "the sparse kernel can not be sparsified.");
386 
387     // On success, update the linalg operands and maps in place.
388     Value res = linalgOp.getResult(0);
389     auto stt = tryGetSparseTensorType(res);
390     auto [idxMap, itTp] = *transMap;
391 
392     rewriter.startOpModification(linalgOp);
393     linalgOp.setIndexingMapsAttr(idxMap);
394     linalgOp.setIteratorTypesAttr(itTp);
395     // Use demapped arguments.
396     linalgOp.getInputsMutable().assign(adaptor.getInputs());
397     linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
398     res.setType(adaptor.getOutputs()[0].getType());
399     rewriter.finalizeOpModification(linalgOp);
400 
401     rewriter.setInsertionPointAfter(linalgOp);
402     if (stt && stt->hasEncoding()) {
403       Value t = genRemap(rewriter, stt->getEncoding(), res);
404       rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
405     }
406     return success();
407   }
408 };
409 
410 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
411   using OpRewritePattern::OpRewritePattern;
412   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
413                                 PatternRewriter &rewriter) const override {
414     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
415         hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
416         !hasAnySparseOperandOrResult(linalgOp)) {
417       return failure();
418     }
419 
420     const StringRef sorted = "sorted";
421     if (linalgOp->hasAttr(sorted))
422       return failure();
423 
424     auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
425     bool isAdmissible = false;
426     AffineMap order;
427     // A const list of all masks that we use for iteration graph
428     // computation. It must be ordered from most strict to least strict.
429     // Ideally (though not guaranteed), the earlier a constraint mask
430     // can be satisfied, the faster the generated kernel will be.
431     const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
432                            SortMask::kIncludeDenseInput,
433                            SortMask::kIncludeDenseOutput,
434                            SortMask::kSparseOnly};
435     for (const SortMask mask : allMasks) {
436       order = scheduler.sort(mask);
437       if (order) {
438         if (isAdmissibleOrder(linalgOp, order)) {
439           isAdmissible = true;
440           break;
441         }
442         // else try a set of less strict constraints.
443       }
444     }
445 
446     if (!order) {
447       // Cycles detected.
448       if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
449         return rewriter.notifyMatchFailure(
450             linalgOp, "the sparse kernel can not be scheduled: loop detected.");
451       }
452       return success();
453     }
454 
455     if (!isAdmissible) {
456       return rewriter.notifyMatchFailure(
457           linalgOp, "the sparse kernel can not be scheduled.");
458     }
459 
460     // Marks the GenericOp to avoid recursive matching.
461     rewriter.modifyOpInPlace(linalgOp, [&]() {
462       linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
463     });
464 
465     // Already sorted.
466     if (order.isIdentity())
467       return success();
468 
469     assert(order.isPermutation());
470     // `order` is the original loop -> sorted loop map.
471     ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
472     SmallVector<Attribute> curItTypes;
473     curItTypes.reserve(preItTypes.size());
474     for (AffineExpr expr : order.getResults()) {
475       unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
476       curItTypes.push_back(preItTypes[loopID]);
477     }
478 
479     // Invert `order` to get the sorted loop -> original loop map.
480     order = inversePermutation(order);
481     SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
482     for (AffineMap &idxMap : idxMaps)
483       idxMap = idxMap.compose(order); // sorted loop -> lvl map
484 
485     rewriter.startOpModification(linalgOp);
486     linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
487     linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
488     rewriter.finalizeOpModification(linalgOp);
489 
490     return success();
491   }
492 
493 private:
494   /// Whether the loop order is admissible for sparsification.
495   static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
496     if (!hasAnySparseResult(linalgOp))
497       return true;
498 
499     OpOperand *lhs = linalgOp.getDpsInitOperand(0);
500     unsigned nest = 0;
501     const auto iteratorTypes = linalgOp.getIteratorTypesArray();
502     for (const AffineExpr l : order.getResults()) {
503       unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
504       auto itTp =
505           linalgOp.getIteratorTypes()[loopId].cast<linalg::IteratorTypeAttr>();
506       if (linalg::isReductionIterator(itTp.getValue()))
507         break; // terminate at first reduction
508       nest++;
509     }
510     // Determine admissible dynamic insertion situations:
511     // (1) fully injective, since there are no reductions,
512     // (2) admissible 1-d expansion in innermost dimension.
513     return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
514   }
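  // For illustration, for a matmul-like kernel with iterator types
  // [parallel, parallel, reduction] and a rank-2 sparse output, orders such
  // as (d0, d1, d2) or (d1, d0, d2) give nest == 2 >= rank - 1 and are
  // admissible, whereas an order starting with the reduction loop, e.g.
  // (d2, d0, d1), gives nest == 0 and is rejected.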
515 
516   // Last resort cycle resolution.
517   static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
518                                     linalg::LinalgOp linalgOp,
519                                     PatternRewriter &rewriter) {
520     // Compute topological sort while leaving out every sparse input tensor in
521     // succession until an acyclic iteration graph results.
522     for (OpOperand *t : linalgOp.getDpsInputOperands()) {
523       Value tval = t->get();
524       auto srcEnc = getSparseTensorEncoding(tval.getType());
525       // The constraints introduced by compound index expressions are
526       // complicated. Skip them.
527       AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
528       bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
529         return !llvm::isa<AffineDimExpr>(exp);
530       });
531       if (!srcEnc || hasCompExpr)
532         continue;
533 
534       // Try scheduling loop without constraints from `tval`.
535       AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
536       if (!order) // still cyclic
537         continue;
538 
539       // Found an input tensor that resolves the cycle by inserting a
540       // conversion into a sparse tensor that adheres to the iteration
541       // graph order.
542       auto stt = getSparseTensorType(tval);
543       assert(stt.isIdentity());
544       order = inversePermutation(order);
545       // sorted loop -> lvl map.
546       idxMap = idxMap.compose(order);
547 
548       // Find a permutation such that the results in `idxMap` are sorted.
549       // For example, given
550       //  (d0, d1, d2, d3) -> (d2, d1, d0)
551       // with loops scheduled in the order d0->d1->d2->d3; to resolve the cycle,
552       // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
553       // transposed tensor's levels are visited in the same order as the loop
554       // scheduling order.
555       SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
556       for (AffineExpr expr : idxMap.getResults()) {
557         unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
558         lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
559       }
560       std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
561         return lhs.first < rhs.first;
562       });
563       SmallVector<unsigned> perm =
564           llvm::to_vector(llvm::make_second_range(lvlSeq));
565       auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
566       // The result of the idxMap must be unsorted.
567       assert(!dimToLvl.isIdentity());
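      // For illustration, if the composed idxMap yields (d2, d1, d0), then
      // lvlSeq is [(2,0), (1,1), (0,2)]; sorted by level it becomes
      // [(0,2), (1,1), (2,0)], so perm = [2, 1, 0] and
      // dimToLvl = (d0, d1, d2) -> (d2, d1, d0).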
568 
569       // Insert the transpose.
570       rewriter.setInsertionPoint(linalgOp);
571       RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
572       Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
573       rewriter.modifyOpInPlace(linalgOp, [&]() {
574         linalgOp->setOperand(t->getOperandNumber(), dst);
575       });
576       return success();
577     }
578     // Cannot be resolved with a single conversion.
579     // TODO: convert more than one?
580     return failure();
581   }
582 };
583 
584 //===----------------------------------------------------------------------===//
585 // Reinterpret Map Rewriters for operations other than linalg.generics
586 //===----------------------------------------------------------------------===//
587 
588 template <typename AllocOp>
589 struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
590   using OpRewritePattern<AllocOp>::OpRewritePattern;
591   LogicalResult matchAndRewrite(AllocOp op,
592                                 PatternRewriter &rewriter) const override {
593     if (!hasAnyNonIdentityOperandsOrResults(op))
594       return failure();
595 
596     Location loc = op.getLoc();
597     auto stt = getSparseTensorType(op.getResult());
598 
599     SmallVector<Value> maxDimCrds;
600     maxDimCrds.reserve(stt.getDimRank());
601     ValueRange dynSz = op.getDynamicSizes();
602     for (int64_t dimSz : stt.getDimShape()) {
603       if (ShapedType::isDynamic(dimSz)) {
604         Value maxCrd = rewriter.create<arith::SubIOp>(
605             loc, dynSz.front(), constantIndex(rewriter, loc, 1));
606         maxDimCrds.push_back(maxCrd);
607         dynSz = dynSz.drop_front();
608       } else {
609         maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
610       }
611     }
612 
613     ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
614                                               CrdTransDirectionKind::dim2lvl);
615     auto lvlShape = stt.getLvlShape();
616     SmallVector<Value> dynLvlSzs;
617     for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
618       if (ShapedType::isDynamic(lvlShape[i])) {
619         Value sz = rewriter.create<arith::AddIOp>(
620             loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
621         dynLvlSzs.push_back(sz);
622       }
623     }
624 
625     assert(dynSz.empty()); // should have consumed all.
626     rewriter.startOpModification(op);
627     op->setOperands(dynLvlSzs);
628     op.getResult().setType(stt.getDemappedType());
629     rewriter.finalizeOpModification(op);
630     rewriter.setInsertionPointAfter(op);
631 
632     Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
633     rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
634     return success();
635   }
636 };
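// For illustration (a sketch, assuming an encoding whose dim2lvl map swaps
// the two dimensions): for a tensor.empty(%d0, %d1) of type
// tensor<?x?xf32, #enc>, the maximal dim coordinates (%d0 - 1, %d1 - 1) are
// translated to level coordinates (%d1 - 1, %d0 - 1), so the demapped
// allocation receives dynamic level sizes equal to %d1 and %d0 (recomputed
// via the sub/add of 1 above), its result type becomes the demapped
// (level-space) tensor type, and a trailing reinterpret_map restores the
// original type for all other users.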
637 
638 struct TensorInsertDemapper
639     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
640   using DemapInsRewriter::DemapInsRewriter;
641   LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
642                           PatternRewriter &rewriter) const {
643     if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
644       return failure();
645 
646     Location loc = op.getLoc();
647     auto stt = getSparseTensorType(op.getResult());
648     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
649                                           CrdTransDirectionKind::dim2lvl);
650     auto insertOp = rewriter.create<tensor::InsertOp>(
651         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
652 
653     Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
654     rewriter.replaceOp(op, out);
655     return success();
656   }
657 };
658 
659 struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
660   using OpRewritePattern::OpRewritePattern;
661   LogicalResult matchAndRewrite(AssembleOp op,
662                                 PatternRewriter &rewriter) const override {
663     if (!hasAnyNonIdentityOperandsOrResults(op))
664       return failure();
665 
666     assert(hasAnySparseResult(op));
667     auto stt = getSparseTensorType(op.getResult());
668     rewriter.modifyOpInPlace(
669         op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
670     rewriter.setInsertionPointAfter(op);
671     Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
672     rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
673     return success();
674   }
675 };
676 
677 struct SparseDisassembleDemapper
678     : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
679   using DemapInsRewriter::DemapInsRewriter;
680   LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
681                           PatternRewriter &rewriter) const {
682     if (!hasAnyNonIdentityOperandsOrResults(op))
683       return failure();
684 
685     assert(hasAnySparseOperandOrResult(op));
686     rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
687       op.getTensorMutable().assign(adaptor.getTensor());
688     });
689     return success();
690   }
691 };
692 
693 struct ForeachOpDemapper
694     : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
695   using DemapInsRewriter::DemapInsRewriter;
696   LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
697                           PatternRewriter &rewriter) const {
698     // Only handle operations with sparse inputs/outputs that have
699     // non-identity dim2lvl maps.
700     if (!hasAnyNonIdentityOperandsOrResults(op))
701       return failure();
702 
703     // TODO: demap constant as well.
704     if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
705       if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
706         return failure();
707 
708     Location loc = op.getLoc();
709     // Cache the type information since we update the foreach op in-place.
710     auto srcStt = getSparseTensorType(op.getTensor());
711     SmallVector<Type> prevRetTps(op.getResultTypes());
712 
713     rewriter.startOpModification(op);
714     op.getTensorMutable().assign(adaptor.getTensor());
715     op.getInitArgsMutable().assign(adaptor.getInitArgs());
716     // Update results' types.
717     for (auto r : op.getResults())
718       if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
719         r.setType(stt->getDemappedType());
720 
721     Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
722     // Update the foreach body.
723     SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
724     blockArgTps.push_back(srcStt.getElementType());
725     blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
726                        adaptor.getInitArgs().getTypes().end());
727     Block *body = op.getBody();
728     // Block Args: [dimCrd, val, initArgs]
729     unsigned preArgNum = body->getNumArguments();
730     for (Type t : blockArgTps)
731       body->addArgument(t, loc);
732 
733     // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
734     rewriter.setInsertionPointToStart(body);
735     ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
736 
737     ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
738                                               CrdTransDirectionKind::lvl2dim);
739     rewriter.replaceAllUsesWith(
740         body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
741     body->eraseArguments(0, srcStt.getDimRank());
742     // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
743     unsigned numInitArgs = op.getInitArgs().size();
744     rewriter.replaceAllUsesWith(body->getArgument(0),
745                                 body->getArgument(lvlRank + numInitArgs + 1));
746     body->eraseArgument(0);
747     // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
748     ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
749     ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
750     // Remap back before replacement.
751     SmallVector<Value> reMappedArgs =
752         remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
753     rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
754     body->eraseArguments(0, numInitArgs);
755     // Block Args: [lvlCrds, DemappedArgs] and we are done.
756 
757     // Update yield operations.
758     if (numInitArgs != 0) {
759       rewriter.setInsertionPointToEnd(body);
760       auto yield = llvm::cast<YieldOp>(body->getTerminator());
761       if (auto stt = tryGetSparseTensorType(yield.getResult());
762           stt && !stt->isIdentity()) {
763         Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
764         rewriter.create<YieldOp>(loc, y);
765         rewriter.eraseOp(yield);
766       }
767     }
768     rewriter.finalizeOpModification(op);
769 
770     rewriter.setInsertionPointAfter(op);
771     SmallVector<Value> outs =
772         remapValueRange(rewriter, prevRetTps, op.getResults());
773 
774     // Replace all the uses of the foreach results, except the uses in the
775     // reinterpret_maps used to remap the outputs.
776     for (auto [from, to] : llvm::zip(op.getResults(), outs))
777       rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
778 
779     return success();
780   }
781 };
782 
783 } // namespace
784 
785 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
786                                         ReinterpretMapScope scope) {
787   if (scope == ReinterpretMapScope::kAll ||
788       scope == ReinterpretMapScope::kGenericOnly) {
789     patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
790         patterns.getContext());
791   }
792   if (scope == ReinterpretMapScope::kAll ||
793       scope == ReinterpretMapScope::kExceptGeneric) {
794     patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
795                  TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
796                  SparseDisassembleDemapper, TensorInsertDemapper,
797                  ForeachOpDemapper>(patterns.getContext());
798   }
799 }
800