//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP helper to implement a rewriter that demaps all its inputs.
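// Subclasses provide the actual rewriting of the demapped operation via
//   LogicalResult rewriteOp(SourceOp op, OpAdaptor adaptor,
//                           PatternRewriter &rewriter) const;
// which is invoked below through the CRTP cast.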
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns)
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
                                                          rewriter);
  }
};

// Collects the AffineDimExprs used in an affine expression into a BitVector
// indexed by dimension position.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible as a level expression.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floordiv and ceildiv on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
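// For example (illustrative only), for the input map
//   (d0, d1) -> (d0, d1 floordiv 2, d1 mod 2)
// levels 1 and 2 use floordiv/mod and are inadmissible, so the returned pair
// marks levels {1, 2} in the first BitVector and dimension {d1} in the second.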
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds the AffineMap to replace the idx in idxMap with lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdlvls.count()` dims
// for the replaced levels, and the remaining ones for the unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map; it is computed by
  // composing idx2Dim and dim2Lvl. They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimensions at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dimensions.
  // Note that the results of the inverse map computed previously do not follow
  // this convention, so we fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator type.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns failure if the index map cannot be translated to
// an admissible form.
// Returns the translated index map array and the iterator type array.
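// For example (illustrative only), with a 2x2 block-sparse operand whose
// dim2lvl map is (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// and an identity index map, the composed idx2lvl map contains floordiv/mod
// expressions; the fixed-point loop below rewrites it into the identity
// (l0, l1, l2, l3) -> (l0, l1, l2, l3) and composes the corresponding inverse
// map into the index maps of all other operands.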
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing idx2dim with dim2lvl, we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
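  // For example (illustrative only), for a static level size of 4 and the
  // level variable at position p (call it dp), the mapping rewrites
  // (dp floordiv 4) to 0 and (dp mod 4) to dp, folding away the block-size
  // expressions produced by the dim2lvl translation.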
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced to
        // resolve previously inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map into all the index maps to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
static bool hasNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startRootUpdate(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeRootUpdate(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

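// Demaps a sparse tensor allocation (e.g., bufferization.alloc_tensor or
// tensor.empty) with a non-identity dim2lvl map: the maximum dimension
// coordinates are translated to level coordinates to derive the dynamic level
// sizes, the op is updated in place to allocate the demapped (level) type, and
// a reinterpret_map is inserted to map the result back for remaining users.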
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startRootUpdate(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeRootUpdate(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

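// Demaps a tensor.insert into a sparse tensor: the dimension coordinates are
// translated to level coordinates and a sparse_tensor.insert on the demapped
// destination is generated; the result is remapped back for the users of the
// original op.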
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

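// Demaps a sparse_tensor.foreach op with non-identity dim2lvl maps: the
// tensor and init arguments are replaced by their demapped counterparts, the
// body's dimension-coordinate block arguments are rewritten to level
// coordinates (translated back to dimension coordinates for existing uses),
// the yields are demapped, and the results are remapped after the op.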
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse inputs/outputs that have
    // non-identity dim2lvl maps.
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startRootUpdate(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getResult());
          stt && !stt->isIdentity()) {
        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeRootUpdate(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}