106a65ce5SPeiming Liu //===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===/
27cfac1beSAart Bik //
37cfac1beSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47cfac1beSAart Bik // See https://llvm.org/LICENSE.txt for license information.
57cfac1beSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67cfac1beSAart Bik //
77cfac1beSAart Bik //===----------------------------------------------------------------------===//
87cfac1beSAart Bik
9365777ecSAart Bik #include "Utils/CodegenUtils.h"
10365777ecSAart Bik #include "Utils/IterationGraphSorter.h"
11c0d78c42SPeiming Liu
12ef100c22SPeiming Liu #include "mlir/Dialect/Affine/IR/AffineOps.h"
13c0d78c42SPeiming Liu #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14e5999787SAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h"
15e5999787SAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h"
167cfac1beSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
177cfac1beSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
187cfac1beSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19ef100c22SPeiming Liu #include "mlir/Dialect/Tensor/IR/Tensor.h"
20c99951d4SPeiming Liu #include "mlir/IR/AffineExprVisitor.h"
21ef100c22SPeiming Liu #include "mlir/IR/AffineMap.h"
22ef100c22SPeiming Liu
23ef100c22SPeiming Liu using namespace mlir;
24ef100c22SPeiming Liu using namespace mlir::sparse_tensor;
257cfac1beSAart Bik
26c99951d4SPeiming Liu namespace {
27c99951d4SPeiming Liu
28c99951d4SPeiming Liu //===----------------------------------------------------------------------===//
29c99951d4SPeiming Liu // File Local Helper classes.
30c99951d4SPeiming Liu //===----------------------------------------------------------------------===//
31c99951d4SPeiming Liu
32c99951d4SPeiming Liu // CRTP to help implementing a rewriter that demaps all its inputs.
33c99951d4SPeiming Liu template <typename SubClass, typename SourceOp>
34c99951d4SPeiming Liu struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
35c99951d4SPeiming Liu using OpRewritePattern<SourceOp>::OpRewritePattern;
36c99951d4SPeiming Liu using OpAdaptor = typename SourceOp::Adaptor;
37c99951d4SPeiming Liu
matchAndRewrite__anonbf08e2170111::DemapInsRewriter38c99951d4SPeiming Liu LogicalResult matchAndRewrite(SourceOp op,
39c99951d4SPeiming Liu PatternRewriter &rewriter) const override {
40c99951d4SPeiming Liu Location loc = op.getLoc();
41986287e7SMatthias Springer
42c99951d4SPeiming Liu // Demaps non-trivial inputs.
43986287e7SMatthias Springer bool changed = false;
44c99951d4SPeiming Liu SmallVector<Value> deMappedIns(op->getOperands());
45986287e7SMatthias Springer for (Value &in : deMappedIns) {
46986287e7SMatthias Springer if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
47c99951d4SPeiming Liu in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
48986287e7SMatthias Springer changed = true;
49986287e7SMatthias Springer }
50986287e7SMatthias Springer }
51c99951d4SPeiming Liu
52c99951d4SPeiming Liu // CRTP call.
53c99951d4SPeiming Liu OpAdaptor adaptor(deMappedIns, op);
54986287e7SMatthias Springer LogicalResult status =
55986287e7SMatthias Springer static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
56986287e7SMatthias Springer return changed ? success() : status;
57c99951d4SPeiming Liu }
58c99951d4SPeiming Liu };
59c99951d4SPeiming Liu
// Collects the positions of all AffineDimExprs appearing in an affine
// expression into a BitVector (one bit per dimension position).
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  // Bit `i` is set iff dimension `d_i` occurs in the walked expression.
  BitVector dims;
};
66c99951d4SPeiming Liu
// Determines whether an affine expression is admissible for sparsification:
// output expressions must be plain AffineDimExprs (no add/mul), while input
// expressions may not contain mod, floordiv, or ceildiv. Walk the expression
// post-order, then convert to bool to read the verdict.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs (and outputs).
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};
93c99951d4SPeiming Liu
// The first BitVector stores the levels where inadmissible exprs appear.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
97c99951d4SPeiming Liu using InadmissInfo = std::pair<BitVector, BitVector>;
98c99951d4SPeiming Liu
99c99951d4SPeiming Liu } // namespace
100c99951d4SPeiming Liu
101e5999787SAart Bik //===----------------------------------------------------------------------===//
1023426d330SPeiming Liu // File Local Helper methods.
103e5999787SAart Bik //===----------------------------------------------------------------------===//
104e5999787SAart Bik
105c99951d4SPeiming Liu // Collects the inadmissible affine expression imposed on levels.
collectInadmissInfo(AffineMap map,bool isOutput)106c99951d4SPeiming Liu static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
107c99951d4SPeiming Liu auto ret = std::make_pair(BitVector(map.getNumResults()),
108c99951d4SPeiming Liu BitVector(map.getNumDims()));
109c99951d4SPeiming Liu AffineDimCollector collector(map.getNumDims());
110c99951d4SPeiming Liu for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
111c99951d4SPeiming Liu AffineExprAdmissibleVisitor admissible(isOutput);
112c99951d4SPeiming Liu admissible.walkPostOrder(map.getResult(lvl));
113c99951d4SPeiming Liu if (!admissible) {
114c99951d4SPeiming Liu // Record the inadmissible level.
115c99951d4SPeiming Liu ret.first.set(lvl);
116c99951d4SPeiming Liu // Record the AffineDimExpr that is used in the inadmissible expr.
117c99951d4SPeiming Liu collector.walkPostOrder(map.getResult(lvl));
118e5999787SAart Bik }
119c99951d4SPeiming Liu }
120c99951d4SPeiming Liu ret.second = collector.dims;
121c99951d4SPeiming Liu return ret;
122c99951d4SPeiming Liu }
123c99951d4SPeiming Liu
// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replace every dimensions used
// in `info` to levels, and updates the iterator type array `itTps` for the new
// index variable introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdlvls.count()` for the
// replaced levels, and remaining ones for unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4)
// which is a typical map for block_2to4. The function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which, (l0, l1) together replaces `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal to dim2Lvl map, it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // ID map.
  // TODO: we might fail here, in those case we should really return
  // failure instead of assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    // Cursor into the used dims (results of the inferred inverse) and into
    // the freshly appended unused dims, respectively.
    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExpr that is not used in inadmissible
  // level expressions. We use the first inAdLvl.count() dims to represent the
  // replaced levels, the remaining ones are reserved for unchanged ones.
  // Note that results from the inverse map computed previously do not follow
  // the convention we used, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` to
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; inherit
    // the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals to the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inversed. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inversed. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator type array to reflect the new loop variables.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
237c99951d4SPeiming Liu
// Translates the index map in the linalg::GenericOp from idx->dim map to
// idx->lvl map. Returns std::nullopt if the index map can not be translated
// to an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is a idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim(dim2lvl), we got a idx2lvl Map
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during dim2lvl
  // translation: for a static level size, `lvl floordiv lvlSz` folds to 0 and
  // `lvl mod lvlSz` folds to `lvl` (since 0 <= lvl < lvlSz).
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm: keep eliminating inadmissible expressions until
  // none remain (or until we discover the kernel can not be translated).
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip on dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap were introduced to
        // resolve previous inadmissible expressions. We can not replace them
        // as it might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  };

  // Wrap the translated iterator types as attributes for the op.
  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
337e5999787SAart Bik
338e5999787SAart Bik // Generates a "de"mapping reinterpretation of the map.
genDemap(OpBuilder & builder,SparseTensorEncodingAttr enc,Value val)339e5999787SAart Bik static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
340e5999787SAart Bik Value val) {
341e5999787SAart Bik return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
342e5999787SAart Bik val);
343e5999787SAart Bik }
344e5999787SAart Bik
345e5999787SAart Bik // Generates a "re"mapping reinterpretation of the map.
genRemap(OpBuilder & builder,SparseTensorEncodingAttr enc,Value val)346e5999787SAart Bik static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
347e5999787SAart Bik Value val) {
348e5999787SAart Bik return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
349e5999787SAart Bik }
350e5999787SAart Bik
remapValueRange(OpBuilder & rewriter,TypeRange types,ValueRange outs)3513426d330SPeiming Liu static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
3523426d330SPeiming Liu ValueRange outs) {
3533426d330SPeiming Liu SmallVector<Value> ret(outs);
3543426d330SPeiming Liu assert(outs.size() == types.size());
3553426d330SPeiming Liu for (auto [r, t] : llvm::zip(ret, types))
3563426d330SPeiming Liu if (r.getType() != t)
3573426d330SPeiming Liu r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
3583426d330SPeiming Liu return ret;
3593426d330SPeiming Liu }
3603426d330SPeiming Liu
3613426d330SPeiming Liu namespace {
3623426d330SPeiming Liu
363e5999787SAart Bik //===----------------------------------------------------------------------===//
364e5999787SAart Bik // Rewriting rules for linalg generic ops.
365e5999787SAart Bik //===----------------------------------------------------------------------===//
366e5999787SAart Bik
/// Sparse rewriting rule for the generic `linalg` operation: translates the
/// idx->dim indexing maps to idx->lvl maps over demapped operands, so that
/// later sparsification can operate directly on storage levels.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics and at least one non-identity dim-to-lvl mapping.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    // Remap the (now demapped) result back to its original sparse type so
    // all remaining uses still observe the original encoding.
    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
409e5999787SAart Bik
41006a65ce5SPeiming Liu struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
41106a65ce5SPeiming Liu using OpRewritePattern::OpRewritePattern;
  // Computes an admissible loop order for the generic op via topological
  // sorting of its iteration graph, then permutes the indexing maps and
  // iterator types accordingly. Marks the op with a "sorted" attribute so
  // the pattern does not match it again.
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Only single-output pure-tensor kernels with some sparse operand or
    // result are scheduled; non-identity maps must have been demapped first.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    // Skip ops already processed by this pattern.
    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we used for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though might not be guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is original loop -> sorted loop map.
    // Permute the iterator types into the sorted loop order.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Inverse `order` to get sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }
49206a65ce5SPeiming Liu
49306a65ce5SPeiming Liu private:
49406a65ce5SPeiming Liu /// Whether the loop order is admissible by sparsification.
isAdmissibleOrder__anonbf08e2170411::GenericOpScheduler49506a65ce5SPeiming Liu static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
49606a65ce5SPeiming Liu if (!hasAnySparseResult(linalgOp))
49706a65ce5SPeiming Liu return true;
49806a65ce5SPeiming Liu
49906a65ce5SPeiming Liu OpOperand *lhs = linalgOp.getDpsInitOperand(0);
50006a65ce5SPeiming Liu unsigned nest = 0;
50106a65ce5SPeiming Liu const auto iteratorTypes = linalgOp.getIteratorTypesArray();
50206a65ce5SPeiming Liu for (const AffineExpr l : order.getResults()) {
50306a65ce5SPeiming Liu unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
50406a65ce5SPeiming Liu auto itTp =
505a5757c5bSChristian Sigg cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
50606a65ce5SPeiming Liu if (linalg::isReductionIterator(itTp.getValue()))
50706a65ce5SPeiming Liu break; // terminate at first reduction
50806a65ce5SPeiming Liu nest++;
50906a65ce5SPeiming Liu }
51006a65ce5SPeiming Liu // Determine admissible dynamic insertion situations:
51106a65ce5SPeiming Liu // (1) fully injective, since there are no reductions,
51206a65ce5SPeiming Liu // (2) admissible 1-d expansion in innermost dimension.
51306a65ce5SPeiming Liu return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
51406a65ce5SPeiming Liu };
51506a65ce5SPeiming Liu
51606a65ce5SPeiming Liu // Last resort cycle resolution.
resolveCycle__anonbf08e2170411::GenericOpScheduler51706a65ce5SPeiming Liu static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
51806a65ce5SPeiming Liu linalg::LinalgOp linalgOp,
51906a65ce5SPeiming Liu PatternRewriter &rewriter) {
52006a65ce5SPeiming Liu // Compute topological sort while leaving out every sparse input tensor in
52106a65ce5SPeiming Liu // succession until an acylic iteration graph results.
52206a65ce5SPeiming Liu for (OpOperand *t : linalgOp.getDpsInputOperands()) {
52306a65ce5SPeiming Liu Value tval = t->get();
52406a65ce5SPeiming Liu auto srcEnc = getSparseTensorEncoding(tval.getType());
52506a65ce5SPeiming Liu // The constraints introduced by compound index expression are
52606a65ce5SPeiming Liu // complicated. Skip them.
52706a65ce5SPeiming Liu AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
52806a65ce5SPeiming Liu bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
52906a65ce5SPeiming Liu return !llvm::isa<AffineDimExpr>(exp);
53006a65ce5SPeiming Liu });
53106a65ce5SPeiming Liu if (!srcEnc || hasCompExpr)
53206a65ce5SPeiming Liu continue;
53306a65ce5SPeiming Liu
53406a65ce5SPeiming Liu // Try scheduling loop without constraints from `tval`.
53506a65ce5SPeiming Liu AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
53606a65ce5SPeiming Liu if (!order) // still cyclic
53706a65ce5SPeiming Liu continue;
53806a65ce5SPeiming Liu
53906a65ce5SPeiming Liu // Found an input tensor that resolves the cycle by inserting a
54006a65ce5SPeiming Liu // conversion into a sparse tensor that adheres to the iteration
54106a65ce5SPeiming Liu // graph order.
54206a65ce5SPeiming Liu auto stt = getSparseTensorType(tval);
54306a65ce5SPeiming Liu assert(stt.isIdentity());
54406a65ce5SPeiming Liu order = inversePermutation(order);
54506a65ce5SPeiming Liu // sorted loop -> lvl map.
54606a65ce5SPeiming Liu idxMap = idxMap.compose(order);
54706a65ce5SPeiming Liu
54806a65ce5SPeiming Liu // Found a permutation such that the results in `idxMap` is sorted.
54906a65ce5SPeiming Liu // For example,
55006a65ce5SPeiming Liu // (d0, d1, d2, d3) -> (d2, d1, d0)
55106a65ce5SPeiming Liu // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle,
55206a65ce5SPeiming Liu // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
55306a65ce5SPeiming Liu // transposed tensor's levels are visited in the same order as the loop
55406a65ce5SPeiming Liu // scheduling order.
55506a65ce5SPeiming Liu SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
55606a65ce5SPeiming Liu for (AffineExpr expr : idxMap.getResults()) {
55706a65ce5SPeiming Liu unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
55806a65ce5SPeiming Liu lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
55906a65ce5SPeiming Liu }
560*197c3a3eSKazu Hirata llvm::sort(lvlSeq, llvm::less_first());
56106a65ce5SPeiming Liu SmallVector<unsigned> perm =
56206a65ce5SPeiming Liu llvm::to_vector(llvm::make_second_range(lvlSeq));
56306a65ce5SPeiming Liu auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
56406a65ce5SPeiming Liu // The result of the idxMap must be unsorted.
56506a65ce5SPeiming Liu assert(!dimToLvl.isIdentity());
56606a65ce5SPeiming Liu
56706a65ce5SPeiming Liu // Inserting the transpose
56806a65ce5SPeiming Liu rewriter.setInsertionPoint(linalgOp);
56906a65ce5SPeiming Liu RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
57006a65ce5SPeiming Liu Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
5715fcf907bSMatthias Springer rewriter.modifyOpInPlace(linalgOp, [&]() {
57206a65ce5SPeiming Liu linalgOp->setOperand(t->getOperandNumber(), dst);
57306a65ce5SPeiming Liu });
57442c38b1cSAart Bik
57542c38b1cSAart Bik // Release the transposed form afterwards.
57642c38b1cSAart Bik // TODO: CSE when used in more than one following op?
57742c38b1cSAart Bik rewriter.setInsertionPointAfter(linalgOp);
57842c38b1cSAart Bik rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);
57942c38b1cSAart Bik
58006a65ce5SPeiming Liu return success();
58106a65ce5SPeiming Liu }
58206a65ce5SPeiming Liu // Cannot be resolved with a single conversion.
58306a65ce5SPeiming Liu // TODO: convert more than one?
58406a65ce5SPeiming Liu return failure();
58506a65ce5SPeiming Liu }
58606a65ce5SPeiming Liu };
58706a65ce5SPeiming Liu
588e5999787SAart Bik //===----------------------------------------------------------------------===//
5893426d330SPeiming Liu // Reinterpret Map Rewriters for operations other than linalg.generics
590e5999787SAart Bik //===----------------------------------------------------------------------===//
5917cfac1beSAart Bik
// Demaps a sparse tensor allocation (bufferization.alloc_tensor or
// tensor.empty) with a non-identity dim2lvl map: the op is updated in-place
// to allocate the demapped (level-space) type, with its dynamic sizes
// translated from dimension space to level space, and a remap is inserted
// so every other user still sees the original mapped type.
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    // Compute the maximum coordinate (size - 1) in every dimension,
    // consuming one dynamic-size operand per dynamic dimension.
    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    // Translate the maximum dimension coordinates into level coordinates and
    // rebuild each dynamic level size as max coordinate + 1.
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    // In-place update: allocate the demapped type with level-space sizes.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    // Remap the result and redirect all other users to the mapped value.
    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
641c0d78c42SPeiming Liu
6423426d330SPeiming Liu struct TensorInsertDemapper
6433426d330SPeiming Liu : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
6443426d330SPeiming Liu using DemapInsRewriter::DemapInsRewriter;
rewriteOp__anonbf08e2170411::TensorInsertDemapper6453426d330SPeiming Liu LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
646c81a2c05SPeiming Liu PatternRewriter &rewriter) const {
64794e27c26SPeiming Liu if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
6483426d330SPeiming Liu return failure();
6493426d330SPeiming Liu
650ef100c22SPeiming Liu Location loc = op.getLoc();
651ef100c22SPeiming Liu auto stt = getSparseTensorType(op.getResult());
652ef100c22SPeiming Liu ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
653ef100c22SPeiming Liu CrdTransDirectionKind::dim2lvl);
65494e27c26SPeiming Liu auto insertOp = rewriter.create<tensor::InsertOp>(
655c81a2c05SPeiming Liu loc, op.getScalar(), adaptor.getDest(), lvlCrd);
6563426d330SPeiming Liu
6573426d330SPeiming Liu Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
6583426d330SPeiming Liu rewriter.replaceOp(op, out);
6593426d330SPeiming Liu return success();
6603426d330SPeiming Liu }
6613426d330SPeiming Liu };
6623426d330SPeiming Liu
66307bf1ddbSPeiming Liu struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
66407bf1ddbSPeiming Liu using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anonbf08e2170411::SparseAssembleDemapper66507bf1ddbSPeiming Liu LogicalResult matchAndRewrite(AssembleOp op,
66607bf1ddbSPeiming Liu PatternRewriter &rewriter) const override {
66707bf1ddbSPeiming Liu if (!hasAnyNonIdentityOperandsOrResults(op))
66807bf1ddbSPeiming Liu return failure();
66907bf1ddbSPeiming Liu
67007bf1ddbSPeiming Liu assert(hasAnySparseResult(op));
67107bf1ddbSPeiming Liu auto stt = getSparseTensorType(op.getResult());
67207bf1ddbSPeiming Liu rewriter.modifyOpInPlace(
67307bf1ddbSPeiming Liu op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
67407bf1ddbSPeiming Liu rewriter.setInsertionPointAfter(op);
67507bf1ddbSPeiming Liu Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
67607bf1ddbSPeiming Liu rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
67707bf1ddbSPeiming Liu return success();
67807bf1ddbSPeiming Liu }
67907bf1ddbSPeiming Liu };
68007bf1ddbSPeiming Liu
68107bf1ddbSPeiming Liu struct SparseDisassembleDemapper
68207bf1ddbSPeiming Liu : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
68307bf1ddbSPeiming Liu using DemapInsRewriter::DemapInsRewriter;
rewriteOp__anonbf08e2170411::SparseDisassembleDemapper68407bf1ddbSPeiming Liu LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
68507bf1ddbSPeiming Liu PatternRewriter &rewriter) const {
68607bf1ddbSPeiming Liu if (!hasAnyNonIdentityOperandsOrResults(op))
68707bf1ddbSPeiming Liu return failure();
68807bf1ddbSPeiming Liu
68907bf1ddbSPeiming Liu assert(hasAnySparseOperandOrResult(op));
69007bf1ddbSPeiming Liu rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
69107bf1ddbSPeiming Liu op.getTensorMutable().assign(adaptor.getTensor());
69207bf1ddbSPeiming Liu });
69307bf1ddbSPeiming Liu return success();
69407bf1ddbSPeiming Liu }
69507bf1ddbSPeiming Liu };
69607bf1ddbSPeiming Liu
// Demaps a sparse_tensor.foreach: the iteration is switched to run over the
// demapped (level-space) source tensor, level coordinates are translated
// back to dimension coordinates for the body, and demapped init/result
// values are remapped around the op.
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity dim2lvl
    // maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    // Switch the operands to their demapped counterparts.
    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body: append the new (level-space) arguments first,
    // then retire the old (dimension-space) ones one group at a time.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    // Recompute the old dimension coordinates from the new level coordinates
    // and retire the old dimension arguments.
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, DemappedArgs] and we are done.

    // Update yield operations: demap the yielded value when its type has a
    // non-identity mapping.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    // Remap the (now demapped) results back to the cached original types.
    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};
787ef100c22SPeiming Liu
7887cfac1beSAart Bik } // namespace
7897cfac1beSAart Bik
populateSparseReinterpretMap(RewritePatternSet & patterns,ReinterpretMapScope scope)7906a93da99SPeiming Liu void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
791ef100c22SPeiming Liu ReinterpretMapScope scope) {
792ef100c22SPeiming Liu if (scope == ReinterpretMapScope::kAll ||
793e5999787SAart Bik scope == ReinterpretMapScope::kGenericOnly) {
79406a65ce5SPeiming Liu patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
79506a65ce5SPeiming Liu patterns.getContext());
796e5999787SAart Bik }
797e5999787SAart Bik if (scope == ReinterpretMapScope::kAll ||
798ef100c22SPeiming Liu scope == ReinterpretMapScope::kExceptGeneric) {
799c99951d4SPeiming Liu patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
80007bf1ddbSPeiming Liu TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
80107bf1ddbSPeiming Liu SparseDisassembleDemapper, TensorInsertDemapper,
802c99951d4SPeiming Liu ForeachOpDemapper>(patterns.getContext());
803ef100c22SPeiming Liu }
804ef100c22SPeiming Liu }
805