xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (revision 197c3a3efc703711ac8f14bc4f1765eaadb8e5bc)
//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
27cfac1beSAart Bik //
37cfac1beSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47cfac1beSAart Bik // See https://llvm.org/LICENSE.txt for license information.
57cfac1beSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67cfac1beSAart Bik //
77cfac1beSAart Bik //===----------------------------------------------------------------------===//
87cfac1beSAart Bik 
9365777ecSAart Bik #include "Utils/CodegenUtils.h"
10365777ecSAart Bik #include "Utils/IterationGraphSorter.h"
11c0d78c42SPeiming Liu 
12ef100c22SPeiming Liu #include "mlir/Dialect/Affine/IR/AffineOps.h"
13c0d78c42SPeiming Liu #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14e5999787SAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h"
15e5999787SAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h"
167cfac1beSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
177cfac1beSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
187cfac1beSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19ef100c22SPeiming Liu #include "mlir/Dialect/Tensor/IR/Tensor.h"
20c99951d4SPeiming Liu #include "mlir/IR/AffineExprVisitor.h"
21ef100c22SPeiming Liu #include "mlir/IR/AffineMap.h"
22ef100c22SPeiming Liu 
23ef100c22SPeiming Liu using namespace mlir;
24ef100c22SPeiming Liu using namespace mlir::sparse_tensor;
257cfac1beSAart Bik 
26c99951d4SPeiming Liu namespace {
27c99951d4SPeiming Liu 
28c99951d4SPeiming Liu //===----------------------------------------------------------------------===//
29c99951d4SPeiming Liu // File Local Helper classes.
30c99951d4SPeiming Liu //===----------------------------------------------------------------------===//
31c99951d4SPeiming Liu 
// CRTP helper for rewriters that first "demap" (strip non-identity dim2lvl
// maps from) all sparse operands before delegating to the subclass rewrite.
// SubClass must provide:
//   LogicalResult rewriteOp(SourceOp, OpAdaptor, PatternRewriter &) const;
// which receives an adaptor whose operands are the demapped values.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs: any operand that is a sparse tensor with a
    // non-identity dim2lvl map is wrapped in a ReinterpretMapOp producing its
    // demapped (level-space) type.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    // If demapping already created new IR above, report success regardless of
    // the subclass outcome so the driver records the in-place change;
    // otherwise propagate the subclass status.
    return changed ? success() : status;
  }
};
59c99951d4SPeiming Liu 
// Collects the positions of all AffineDimExprs referenced by the visited
// affine expressions into a BitVector (one bit per dimension index).
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  // `dimNum` is the total dimension count, i.e., the BitVector width.
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};
66c99951d4SPeiming Liu 
// Checks whether an affine expression is admissible as a sparse level
// expression: output maps may only use plain AffineDimExprs, while input
// maps additionally allow add/mul but never mod, floordiv, or ceildiv.
// Use via walkPostOrder, then convert to bool for the verdict.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on both inputs and outputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};
93c99951d4SPeiming Liu 
// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;
98c99951d4SPeiming Liu 
99c99951d4SPeiming Liu } // namespace
100c99951d4SPeiming Liu 
101e5999787SAart Bik //===----------------------------------------------------------------------===//
1023426d330SPeiming Liu // File Local Helper methods.
103e5999787SAart Bik //===----------------------------------------------------------------------===//
104e5999787SAart Bik 
105c99951d4SPeiming Liu // Collects the inadmissible affine expression imposed on levels.
collectInadmissInfo(AffineMap map,bool isOutput)106c99951d4SPeiming Liu static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
107c99951d4SPeiming Liu   auto ret = std::make_pair(BitVector(map.getNumResults()),
108c99951d4SPeiming Liu                             BitVector(map.getNumDims()));
109c99951d4SPeiming Liu   AffineDimCollector collector(map.getNumDims());
110c99951d4SPeiming Liu   for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
111c99951d4SPeiming Liu     AffineExprAdmissibleVisitor admissible(isOutput);
112c99951d4SPeiming Liu     admissible.walkPostOrder(map.getResult(lvl));
113c99951d4SPeiming Liu     if (!admissible) {
114c99951d4SPeiming Liu       // Record the inadmissible level.
115c99951d4SPeiming Liu       ret.first.set(lvl);
116c99951d4SPeiming Liu       // Record the AffineDimExpr that is used in the inadmissible expr.
117c99951d4SPeiming Liu       collector.walkPostOrder(map.getResult(lvl));
118e5999787SAart Bik     }
119c99951d4SPeiming Liu   }
120c99951d4SPeiming Liu   ret.second = collector.dims;
121c99951d4SPeiming Liu   return ret;
122c99951d4SPeiming Liu }
123c99951d4SPeiming Liu 
// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
125c99951d4SPeiming Liu // inadmissible affine expressions can be eliminated.
126c99951d4SPeiming Liu // For example, we can rewrite
127c99951d4SPeiming Liu // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
128c99951d4SPeiming Liu // to
129c99951d4SPeiming Liu // idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
130c99951d4SPeiming Liu // by composing inverse(idxMap), that is
131c99951d4SPeiming Liu // inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
132c99951d4SPeiming Liu //                         -> ((l0 * 2 + l2) floordiv 2,
133c99951d4SPeiming Liu //                             (l1 * 3 + l3) floordiv 3,
134c99951d4SPeiming Liu //                             (l0 * 2 + l2) mod 2,
135c99951d4SPeiming Liu //                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
136c99951d4SPeiming Liu //
137c99951d4SPeiming Liu // This function builds the inverse(idxMap) that replace every dimensions used
138c99951d4SPeiming Liu // in `info` to levels, and updates the iterator type array `itTps` for the new
139c99951d4SPeiming Liu // index variable introduced.
140c99951d4SPeiming Liu //
141c99951d4SPeiming Liu // Note that the returned affine map does not retain the order of the input
142c99951d4SPeiming Liu // affine map. Instead, it always uses the first `info.inAdlvls.count()` for the
143c99951d4SPeiming Liu // replaced levels, and remaining ones for unused dimensions.
144c99951d4SPeiming Liu // For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
146c99951d4SPeiming Liu // which is a typical map for block_2to4. The function returns:
147c99951d4SPeiming Liu // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
148c99951d4SPeiming Liu // in which, (l0, l1) together replaces `d1`, yet they appear
149c99951d4SPeiming Liu // before `d0` in the resulting affine map.
150c99951d4SPeiming Liu // The index (loop) order can later be canonicalized by a topo sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  // inAdLvls: levels with inadmissible expressions;
  // usedDims: dimensions referenced by those expressions.
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal to dim2Lvl map, it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // ID map.
  // TODO: we might fail here, in those case we should really return
  // failure instead of assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    // curUsedDimID walks the inferred inverse results; curUnusedDimID
    // allocates fresh trailing positions for projected-away dimensions.
    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExpr that is not used in inadmissible
  // level expressions. We use the first inAdLvl.count() dim to represent the
  // replaced level, the remainings are reserved for unchanged ones.
  // Note that results from the inverse map computed previously does not follow
  // the convention we used, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` to
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level, inherit
    // the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals to the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it need
      // to be inversed. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inversed. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  // Each inverted dimension is dropped and each inadmissible level adds one
  // new index variable, hence the new dimension count below.
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator type.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
237c99951d4SPeiming Liu 
// Translates the index map in the linalg::GenericOp from idx->dim map to
// idx->lvl map. Returns std::nullopt if the index map can not be translated
// to an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is a idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim(dim2lvl), we got a idx2lvl Map
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during dim2lvl
  // translation. For a level of static size lvlSz, registers the rewrites
  // `lvl floordiv lvlSz -> 0` and `lvl mod lvlSz -> lvl`, which hold because
  // 0 <= lvl < lvlSz.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  // Number of leading index variables already introduced to resolve earlier
  // inadmissible expressions; they must not be rewritten again.
  unsigned boundedNum = 0;
  // A fixed-point algorithm: keep eliminating inadmissible expressions until
  // a full sweep over all operands makes no change.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip on dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` used in the AffineMap is introduced to
        // resolve previous inadmissible expressions. We can not replace them
        // as it might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  };

  // Wrap the final iterator types as attributes for the op.
  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
337e5999787SAart Bik 
338e5999787SAart Bik // Generates a "de"mapping reinterpretation of the map.
genDemap(OpBuilder & builder,SparseTensorEncodingAttr enc,Value val)339e5999787SAart Bik static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
340e5999787SAart Bik                       Value val) {
341e5999787SAart Bik   return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
342e5999787SAart Bik                                           val);
343e5999787SAart Bik }
344e5999787SAart Bik 
345e5999787SAart Bik // Generates a "re"mapping reinterpretation of the map.
genRemap(OpBuilder & builder,SparseTensorEncodingAttr enc,Value val)346e5999787SAart Bik static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
347e5999787SAart Bik                       Value val) {
348e5999787SAart Bik   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
349e5999787SAart Bik }
350e5999787SAart Bik 
remapValueRange(OpBuilder & rewriter,TypeRange types,ValueRange outs)3513426d330SPeiming Liu static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
3523426d330SPeiming Liu                                           ValueRange outs) {
3533426d330SPeiming Liu   SmallVector<Value> ret(outs);
3543426d330SPeiming Liu   assert(outs.size() == types.size());
3553426d330SPeiming Liu   for (auto [r, t] : llvm::zip(ret, types))
3563426d330SPeiming Liu     if (r.getType() != t)
3573426d330SPeiming Liu       r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
3583426d330SPeiming Liu   return ret;
3593426d330SPeiming Liu }
3603426d330SPeiming Liu 
3613426d330SPeiming Liu namespace {
3623426d330SPeiming Liu 
363e5999787SAart Bik //===----------------------------------------------------------------------===//
364e5999787SAart Bik // Rewriting rules for linalg generic ops.
365e5999787SAart Bik //===----------------------------------------------------------------------===//
366e5999787SAart Bik 
/// Sparse rewriting rule for the generic `linalg` operation: translates the
/// indexing maps from index->dimension space to index->level space so that
/// later sparsification can operate directly on levels. Inputs arrive
/// already demapped via the DemapInsRewriter CRTP base.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics; bail out as well if all operands/results already use
    // identity maps (nothing to reinterpret).
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    // The result now carries the demapped (level-space) type.
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    // Remap the result back to the original encoding so existing users keep
    // seeing the original type; the new remap op itself keeps using `res`.
    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
409e5999787SAart Bik 
41006a65ce5SPeiming Liu struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
41106a65ce5SPeiming Liu   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anonbf08e2170411::GenericOpScheduler41206a65ce5SPeiming Liu   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
41306a65ce5SPeiming Liu                                 PatternRewriter &rewriter) const override {
4140a8e3dd4SMatthias Springer     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
41506a65ce5SPeiming Liu         hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
41606a65ce5SPeiming Liu         !hasAnySparseOperandOrResult(linalgOp)) {
41706a65ce5SPeiming Liu       return failure();
41806a65ce5SPeiming Liu     }
41906a65ce5SPeiming Liu 
42006a65ce5SPeiming Liu     const StringRef sorted = "sorted";
42106a65ce5SPeiming Liu     if (linalgOp->hasAttr(sorted))
42206a65ce5SPeiming Liu       return failure();
42306a65ce5SPeiming Liu 
42406a65ce5SPeiming Liu     auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
42506a65ce5SPeiming Liu     bool isAdmissible = false;
42606a65ce5SPeiming Liu     AffineMap order;
42706a65ce5SPeiming Liu     // A const list of all masks that we used for iteration graph
42806a65ce5SPeiming Liu     // computation. Must be ordered from more strict to less strict.
42906a65ce5SPeiming Liu     // Ideally (though might not be guaranteed), the earlier a constraint mask
43006a65ce5SPeiming Liu     // can be satisfied, the faster the generated kernel will be.
4314e2f1521SPeiming Liu     const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
4324e2f1521SPeiming Liu                            SortMask::kIncludeDenseInput,
4334e2f1521SPeiming Liu                            SortMask::kIncludeDenseOutput,
4344e2f1521SPeiming Liu                            SortMask::kSparseOnly};
43506a65ce5SPeiming Liu     for (const SortMask mask : allMasks) {
43606a65ce5SPeiming Liu       order = scheduler.sort(mask);
43706a65ce5SPeiming Liu       if (order) {
43806a65ce5SPeiming Liu         if (isAdmissibleOrder(linalgOp, order)) {
43906a65ce5SPeiming Liu           isAdmissible = true;
44006a65ce5SPeiming Liu           break;
44106a65ce5SPeiming Liu         }
44206a65ce5SPeiming Liu         // else try a set of less strict constraints.
44306a65ce5SPeiming Liu       }
44406a65ce5SPeiming Liu     }
44506a65ce5SPeiming Liu 
44606a65ce5SPeiming Liu     if (!order) {
44706a65ce5SPeiming Liu       // Cycles detected.
44806a65ce5SPeiming Liu       if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
44906a65ce5SPeiming Liu         return rewriter.notifyMatchFailure(
45006a65ce5SPeiming Liu             linalgOp, "the sparse kernel can not be scheduled: loop detected.");
45106a65ce5SPeiming Liu       }
45206a65ce5SPeiming Liu       return success();
45306a65ce5SPeiming Liu     }
45406a65ce5SPeiming Liu 
45506a65ce5SPeiming Liu     if (!isAdmissible) {
45606a65ce5SPeiming Liu       return rewriter.notifyMatchFailure(
45706a65ce5SPeiming Liu           linalgOp, "the sparse kernel can not be scheduled.");
45806a65ce5SPeiming Liu     }
45906a65ce5SPeiming Liu 
46006a65ce5SPeiming Liu     // Marks the GenericOp to avoid recursive matching.
4615fcf907bSMatthias Springer     rewriter.modifyOpInPlace(linalgOp, [&]() {
46206a65ce5SPeiming Liu       linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
463986287e7SMatthias Springer     });
46406a65ce5SPeiming Liu 
46506a65ce5SPeiming Liu     // Already sorted.
46606a65ce5SPeiming Liu     if (order.isIdentity())
467986287e7SMatthias Springer       return success();
46806a65ce5SPeiming Liu 
46906a65ce5SPeiming Liu     assert(order.isPermutation());
47006a65ce5SPeiming Liu     // `order` is orignial loop -> sorted loop map
47106a65ce5SPeiming Liu     ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
47206a65ce5SPeiming Liu     SmallVector<Attribute> curItTypes;
47306a65ce5SPeiming Liu     curItTypes.reserve(preItTypes.size());
47406a65ce5SPeiming Liu     for (AffineExpr expr : order.getResults()) {
47506a65ce5SPeiming Liu       unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
47606a65ce5SPeiming Liu       curItTypes.push_back(preItTypes[loopID]);
47706a65ce5SPeiming Liu     }
47806a65ce5SPeiming Liu 
47906a65ce5SPeiming Liu     // Inverse `order` to get sorted loop -> original loop map
48006a65ce5SPeiming Liu     order = inversePermutation(order);
48106a65ce5SPeiming Liu     SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
48206a65ce5SPeiming Liu     for (AffineMap &idxMap : idxMaps)
48306a65ce5SPeiming Liu       idxMap = idxMap.compose(order); // sorted loop -> lvl map
48406a65ce5SPeiming Liu 
4855fcf907bSMatthias Springer     rewriter.startOpModification(linalgOp);
48606a65ce5SPeiming Liu     linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
48706a65ce5SPeiming Liu     linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
4885fcf907bSMatthias Springer     rewriter.finalizeOpModification(linalgOp);
48906a65ce5SPeiming Liu 
49006a65ce5SPeiming Liu     return success();
49106a65ce5SPeiming Liu   }
49206a65ce5SPeiming Liu 
49306a65ce5SPeiming Liu private:
49406a65ce5SPeiming Liu   /// Whether the loop order is admissible by sparsification.
isAdmissibleOrder__anonbf08e2170411::GenericOpScheduler49506a65ce5SPeiming Liu   static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
49606a65ce5SPeiming Liu     if (!hasAnySparseResult(linalgOp))
49706a65ce5SPeiming Liu       return true;
49806a65ce5SPeiming Liu 
49906a65ce5SPeiming Liu     OpOperand *lhs = linalgOp.getDpsInitOperand(0);
50006a65ce5SPeiming Liu     unsigned nest = 0;
50106a65ce5SPeiming Liu     const auto iteratorTypes = linalgOp.getIteratorTypesArray();
50206a65ce5SPeiming Liu     for (const AffineExpr l : order.getResults()) {
50306a65ce5SPeiming Liu       unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
50406a65ce5SPeiming Liu       auto itTp =
505a5757c5bSChristian Sigg           cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
50606a65ce5SPeiming Liu       if (linalg::isReductionIterator(itTp.getValue()))
50706a65ce5SPeiming Liu         break; // terminate at first reduction
50806a65ce5SPeiming Liu       nest++;
50906a65ce5SPeiming Liu     }
51006a65ce5SPeiming Liu     // Determine admissible dynamic insertion situations:
51106a65ce5SPeiming Liu     // (1) fully injective, since there are no reductions,
51206a65ce5SPeiming Liu     // (2) admissible 1-d expansion in innermost dimension.
51306a65ce5SPeiming Liu     return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
51406a65ce5SPeiming Liu   };
51506a65ce5SPeiming Liu 
  // Last resort cycle resolution: break a cycle in the iteration graph by
  // materializing a transposed (reordered) copy of one sparse input tensor.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute topological sort while leaving out every sparse input tensor in
    // succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expression are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      // Only a sparse input with a pure-permutation indexing map can be
      // converted away to break the cycle.
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling loop without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` is sorted.
      // For example,
      //  (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      // Sort by level; the paired second components then spell out the
      // permutation that makes the levels appear in scheduling order.
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Inserting the transpose
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
58606a65ce5SPeiming Liu };
58706a65ce5SPeiming Liu 
588e5999787SAart Bik //===----------------------------------------------------------------------===//
5893426d330SPeiming Liu // Reinterpret Map Rewriters for operations other than linalg.generics
590e5999787SAart Bik //===----------------------------------------------------------------------===//
5917cfac1beSAart Bik 
592c99951d4SPeiming Liu template <typename AllocOp>
593c99951d4SPeiming Liu struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
594c99951d4SPeiming Liu   using OpRewritePattern<AllocOp>::OpRewritePattern;
matchAndRewrite__anonbf08e2170411::TensorAllocDemapper595c99951d4SPeiming Liu   LogicalResult matchAndRewrite(AllocOp op,
596c0d78c42SPeiming Liu                                 PatternRewriter &rewriter) const override {
59706a65ce5SPeiming Liu     if (!hasAnyNonIdentityOperandsOrResults(op))
598c0d78c42SPeiming Liu       return failure();
599c0d78c42SPeiming Liu 
600c0d78c42SPeiming Liu     Location loc = op.getLoc();
601c0d78c42SPeiming Liu     auto stt = getSparseTensorType(op.getResult());
602c0d78c42SPeiming Liu 
603c0d78c42SPeiming Liu     SmallVector<Value> maxDimCrds;
604c0d78c42SPeiming Liu     maxDimCrds.reserve(stt.getDimRank());
605c0d78c42SPeiming Liu     ValueRange dynSz = op.getDynamicSizes();
606c0d78c42SPeiming Liu     for (int64_t dimSz : stt.getDimShape()) {
607c0d78c42SPeiming Liu       if (ShapedType::isDynamic(dimSz)) {
608c0d78c42SPeiming Liu         Value maxCrd = rewriter.create<arith::SubIOp>(
609c0d78c42SPeiming Liu             loc, dynSz.front(), constantIndex(rewriter, loc, 1));
610c0d78c42SPeiming Liu         maxDimCrds.push_back(maxCrd);
611c0d78c42SPeiming Liu         dynSz = dynSz.drop_front();
612c0d78c42SPeiming Liu       } else {
613c0d78c42SPeiming Liu         maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
614c0d78c42SPeiming Liu       }
615c0d78c42SPeiming Liu     }
616c0d78c42SPeiming Liu 
617c0d78c42SPeiming Liu     ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
618c0d78c42SPeiming Liu                                               CrdTransDirectionKind::dim2lvl);
619c0d78c42SPeiming Liu     auto lvlShape = stt.getLvlShape();
620c0d78c42SPeiming Liu     SmallVector<Value> dynLvlSzs;
621c0d78c42SPeiming Liu     for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
622c0d78c42SPeiming Liu       if (ShapedType::isDynamic(lvlShape[i])) {
623c0d78c42SPeiming Liu         Value sz = rewriter.create<arith::AddIOp>(
624c0d78c42SPeiming Liu             loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
625c0d78c42SPeiming Liu         dynLvlSzs.push_back(sz);
626c0d78c42SPeiming Liu       }
627c0d78c42SPeiming Liu     }
628c0d78c42SPeiming Liu 
629c0d78c42SPeiming Liu     assert(dynSz.empty()); // should have consumed all.
6305fcf907bSMatthias Springer     rewriter.startOpModification(op);
631c0d78c42SPeiming Liu     op->setOperands(dynLvlSzs);
632c0d78c42SPeiming Liu     op.getResult().setType(stt.getDemappedType());
6335fcf907bSMatthias Springer     rewriter.finalizeOpModification(op);
634c0d78c42SPeiming Liu     rewriter.setInsertionPointAfter(op);
635c0d78c42SPeiming Liu 
636c0d78c42SPeiming Liu     Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
637c0d78c42SPeiming Liu     rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
638c0d78c42SPeiming Liu     return success();
639c0d78c42SPeiming Liu   }
640c0d78c42SPeiming Liu };
641c0d78c42SPeiming Liu 
6423426d330SPeiming Liu struct TensorInsertDemapper
6433426d330SPeiming Liu     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
6443426d330SPeiming Liu   using DemapInsRewriter::DemapInsRewriter;
rewriteOp__anonbf08e2170411::TensorInsertDemapper6453426d330SPeiming Liu   LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
646c81a2c05SPeiming Liu                           PatternRewriter &rewriter) const {
64794e27c26SPeiming Liu     if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
6483426d330SPeiming Liu       return failure();
6493426d330SPeiming Liu 
650ef100c22SPeiming Liu     Location loc = op.getLoc();
651ef100c22SPeiming Liu     auto stt = getSparseTensorType(op.getResult());
652ef100c22SPeiming Liu     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
653ef100c22SPeiming Liu                                           CrdTransDirectionKind::dim2lvl);
65494e27c26SPeiming Liu     auto insertOp = rewriter.create<tensor::InsertOp>(
655c81a2c05SPeiming Liu         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
6563426d330SPeiming Liu 
6573426d330SPeiming Liu     Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
6583426d330SPeiming Liu     rewriter.replaceOp(op, out);
6593426d330SPeiming Liu     return success();
6603426d330SPeiming Liu   }
6613426d330SPeiming Liu };
6623426d330SPeiming Liu 
66307bf1ddbSPeiming Liu struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
66407bf1ddbSPeiming Liu   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anonbf08e2170411::SparseAssembleDemapper66507bf1ddbSPeiming Liu   LogicalResult matchAndRewrite(AssembleOp op,
66607bf1ddbSPeiming Liu                                 PatternRewriter &rewriter) const override {
66707bf1ddbSPeiming Liu     if (!hasAnyNonIdentityOperandsOrResults(op))
66807bf1ddbSPeiming Liu       return failure();
66907bf1ddbSPeiming Liu 
67007bf1ddbSPeiming Liu     assert(hasAnySparseResult(op));
67107bf1ddbSPeiming Liu     auto stt = getSparseTensorType(op.getResult());
67207bf1ddbSPeiming Liu     rewriter.modifyOpInPlace(
67307bf1ddbSPeiming Liu         op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
67407bf1ddbSPeiming Liu     rewriter.setInsertionPointAfter(op);
67507bf1ddbSPeiming Liu     Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
67607bf1ddbSPeiming Liu     rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
67707bf1ddbSPeiming Liu     return success();
67807bf1ddbSPeiming Liu   }
67907bf1ddbSPeiming Liu };
68007bf1ddbSPeiming Liu 
68107bf1ddbSPeiming Liu struct SparseDisassembleDemapper
68207bf1ddbSPeiming Liu     : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
68307bf1ddbSPeiming Liu   using DemapInsRewriter::DemapInsRewriter;
rewriteOp__anonbf08e2170411::SparseDisassembleDemapper68407bf1ddbSPeiming Liu   LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
68507bf1ddbSPeiming Liu                           PatternRewriter &rewriter) const {
68607bf1ddbSPeiming Liu     if (!hasAnyNonIdentityOperandsOrResults(op))
68707bf1ddbSPeiming Liu       return failure();
68807bf1ddbSPeiming Liu 
68907bf1ddbSPeiming Liu     assert(hasAnySparseOperandOrResult(op));
69007bf1ddbSPeiming Liu     rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
69107bf1ddbSPeiming Liu       op.getTensorMutable().assign(adaptor.getTensor());
69207bf1ddbSPeiming Liu     });
69307bf1ddbSPeiming Liu     return success();
69407bf1ddbSPeiming Liu   }
69507bf1ddbSPeiming Liu };
69607bf1ddbSPeiming Liu 
// Demaps sparse_tensor.foreach with non-identity dim2lvl maps: the op is
// updated in place to iterate over the demapped tensor (level coordinates),
// its body's block arguments are rebuilt accordingly, and the results are
// remapped back so existing users still see the original dim-space types.
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity dim2lvl
    // maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    // Swap in the demapped source tensor and init args.
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body: append new block arguments for level
    // coordinates, the element value, and the demapped init args; then
    // gradually erase the old (dim-space) ones.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    // Translate the new level coordinates back to dim coordinates and use
    // them to replace the old dim-coordinate arguments.
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    // The old value argument is replaced by the new one (same element type).
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations: a sparse yielded value with a non-identity
    // map must be demapped before leaving the body.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};
787ef100c22SPeiming Liu 
7887cfac1beSAart Bik } // namespace
7897cfac1beSAart Bik 
populateSparseReinterpretMap(RewritePatternSet & patterns,ReinterpretMapScope scope)7906a93da99SPeiming Liu void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
791ef100c22SPeiming Liu                                         ReinterpretMapScope scope) {
792ef100c22SPeiming Liu   if (scope == ReinterpretMapScope::kAll ||
793e5999787SAart Bik       scope == ReinterpretMapScope::kGenericOnly) {
79406a65ce5SPeiming Liu     patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
79506a65ce5SPeiming Liu         patterns.getContext());
796e5999787SAart Bik   }
797e5999787SAart Bik   if (scope == ReinterpretMapScope::kAll ||
798ef100c22SPeiming Liu       scope == ReinterpretMapScope::kExceptGeneric) {
799c99951d4SPeiming Liu     patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
80007bf1ddbSPeiming Liu                  TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
80107bf1ddbSPeiming Liu                  SparseDisassembleDemapper, TensorInsertDemapper,
802c99951d4SPeiming Liu                  ForeachOpDemapper>(patterns.getContext());
803ef100c22SPeiming Liu   }
804ef100c22SPeiming Liu }
805