//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};

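// A minimal usage sketch (with a hypothetical `FooOp`, not part of this file):
// a subclass only provides `rewriteOp` and receives the already demapped
// operands through the adaptor, e.g.
//
//   struct FooOpDemapper : public DemapInsRewriter<FooOpDemapper, FooOp> {
//     using DemapInsRewriter::DemapInsRewriter;
//     LogicalResult rewriteOp(FooOp op, OpAdaptor adaptor,
//                             PatternRewriter &rewriter) const {
//       // The adaptor yields operands of the demapped (level-space) types;
//       // rewrite `op` against those and return success()/failure().
//       return failure();
//     }
//   };
//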
// Collects all AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible as a level expression.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow pure AffineDimExprs on the output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floordiv, and ceildiv on both inputs and outputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
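// For example, for an input tensor with
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// all four level expressions are inadmissible, so the first BitVector holds
// levels {0, 1, 2, 3} and the second holds dims {0, 1}.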
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds the AffineMap that replaces the dims in idxMap with lvls such that
// all the inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not the same as the dim2Lvl map: it is computed by
  // composing dim2Lvl with idx2Dim (i.e., dim2Lvl(idx2Dim)). The two are only
  // equal when idx2Dim is an identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimensions at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dimensions.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level and
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels and needs to
      // be inverted. Get the inverse expression from the inverse map, and fix
      // the mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels and does
      // not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update the iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index maps in the linalg::GenericOp from idx->dim maps to
// idx->lvl maps. Returns std::nullopt if the index maps cannot be translated
// to an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing dim2lvl with idx2dim, we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap were introduced to
        // resolve previous inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all index maps to eliminate the
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  };

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
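
// Illustrative sketch of what translateMap computes for a block-sparse
// (BSR-style, 2x2 blocks) matrix-vector product y = A * x, assuming
//   dimToLvl(A) = (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2).
// The dim-space maps
//   A: (d0, d1) -> (d0, d1)   x: (d0, d1) -> (d1)   y: (d0, d1) -> (d0)
// first become, for A, the idx2lvl map
//   A: (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// and the fixed-point loop then rewrites all maps over four loops, roughly as
//   A: (d0, d1, d2, d3) -> (d0, d1, d2, d3)
//   x: (d0, d1, d2, d3) -> (d1 * 2 + d3)
//   y: (d0, d1, d2, d3) -> (d0 * 2 + d2)
// with the iterator types expanded to match the new loops.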

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}
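
// As an illustration (types abbreviated; #BSR stands for some block-sparse
// encoding and #BSR_demapped for the same encoding with its dimToLvl map
// removed), a genDemap/genRemap pair produces IR roughly like:
//   %d = sparse_tensor.reinterpret_map %t
//          : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #BSR_demapped>
//   %r = sparse_tensor.reinterpret_map %d
//          : tensor<?x?x2x2xf64, #BSR_demapped> to tensor<?x?xf64, #BSR>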

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
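
// The net effect of GenericOpReinterpretMap, sketched on IR (details elided;
// #BSR is just an illustrative non-identity encoding):
//
//   %0 = linalg.generic ... ins(%a : tensor<?x?xf64, #BSR>) ...
//
// becomes
//
//   %a1 = sparse_tensor.reinterpret_map %a ...  // demapped input
//   %1  = linalg.generic ... ins(%a1 : ...) ... // maps/iterators in level space
//   %2  = sparse_tensor.reinterpret_map %1 ...  // remapped sparse result, if any
//
// so that users outside the kernel keep seeing the originally mapped types.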

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from most strict to least strict.
    // Ideally (though this is not guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible for sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  };
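
  // For example, isAdmissibleOrder accepts a matrix product
  // C(i,j) = sum_k A(i,k) * B(k,j) with a sparse C (rank 2) and iterator
  // types [parallel, parallel, reduction] whenever i or j is scheduled first
  // (nest >= rank - 1 == 1), but rejects any order that schedules the
  // reduction loop k outermost (nest == 0).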

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute topological sort while leaving out every sparse input tensor in
    // succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling loop without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Find a permutation such that the results in `idxMap` are sorted.
      // For example, given
      //  (d0, d1, d2, d3) -> (d2, d1, d0)
      // with loops scheduled in the order d0->d1->d2->d3, to resolve the
      // cycle we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such
      // that the transposed tensor's levels are visited in the same order as
      // the loop scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The results of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};
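
// Shape of the cycle-resolution rewrite performed by resolveCycle above,
// sketched on IR (encodings are illustrative; #T stands for the converted,
// transposed-level form of the chosen input):
//   %t = sparse_tensor.convert %a : tensor<?x?xf64, #CSR> to tensor<?x?xf64, #T>
//   %r = linalg.generic ... ins(..., %t, ...) ...
//   bufferization.dealloc_tensor %t : tensor<?x?xf64, #T>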

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
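
// For example (sketch), TensorAllocDemapper turns a dynamically sized 2-D
// allocation with a BSR-style 2x2 block encoding into a level-space
// allocation: the maximal dim coordinates (m-1, n-1) are translated to level
// coordinates ((m-1) floordiv 2, (n-1) floordiv 2, (m-1) mod 2, (n-1) mod 2),
// and adding 1 to each entry at a dynamic level yields the dynamic level
// sizes passed to the demapped allocation.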

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};
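
// For example (sketch), TensorInsertDemapper rewrites an insertion at dim
// coordinates (i, j) into a tensor with a BSR-style 2x2 block encoding as an
// insertion at level coordinates (i floordiv 2, j floordiv 2, i mod 2,
// j mod 2) into the demapped destination, followed by a reinterpret_map back
// to the original type.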

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse inputs/outputs that have
    // non-identity dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}

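// A typical driver for these patterns (sketch; the actual pass that wires
// this up lives elsewhere in the SparseTensor transforms):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));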