//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implementing a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};

// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Determines whether an affine expression is admissible for sparsification.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput) {}

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floordiv and ceildiv on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace
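
// A hedged illustration of the demapping step above (the #CSC encoding and
// SSA names are made up for exposition): given an operand
//   %t : tensor<?x?xf64, #CSC>   // dimToLvl = (i, j) -> (j, i)
// DemapInsRewriter rewires the consumer through
//   %d = sparse_tensor.reinterpret_map %t
//        : tensor<?x?xf64, #CSC> to tensor<?x?xf64, #CSC_demapped>
// where #CSC_demapped carries no dimToLvl map, so the derived rewriter only
// ever observes level (identity) semantics.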
//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExpr that is used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}
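
// A hedged worked example for collectInadmissInfo (exposition only): for an
// input map
//   (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2)
// levels 0 and 2 use inadmissible floordiv/mod expressions, both derived from
// d0, so the returned InadmissInfo is ({lvl0, lvl2}, {d0}).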
// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdlvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal the dim2Lvl map: it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here, in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExpr that is not used in inadmissible
  // level expressions. We use the first inAdLvl.count() dims to represent
  // the replaced levels; the remaining ones are reserved for unchanged ones.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();

  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l`
    // with the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it
      // needs to be inverted. Get the inversion from the inverse map, and fix
      // the mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
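
// A hedged end-to-end illustration of the translation driven by the function
// above (a BSR-like encoding with 2x2 blocks is assumed for exposition):
// starting from a matmul-style access (d0, d1, d2) -> (d0, d1) on a tensor
// with dimToLvl (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2),
// composition yields
//   (d0, d1, d2) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
// which genReplaceDimToLvlMap then rewrites into pure AffineDimExprs over
// fresh level variables inside translateMap below.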
// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns std::nullopt if the index map cannot be translated
// to an admissible form.
// On success, returns the translated index map array and the iterator type
// array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim(dim2lvl), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip on dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced to
        // resolve previous inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all the index maps to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//
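
// A hedged before/after sketch of the rewrite implemented below (encodings
// and SSA names are illustrative only):
//   %0 = linalg.generic ... ins(%a : tensor<?x?xf64, #BSR>) ...
// becomes, after demapping and index-map translation,
//   %d = sparse_tensor.reinterpret_map %a
//        : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #Demapped>
//   %1 = linalg.generic ... ins(%d : ...) ...  // idx2lvl maps, extra loops
//   %0 = sparse_tensor.reinterpret_map %1 ...  // restores the result type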
/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
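
// A hedged scheduling example (for exposition only): for a matmul-like kernel
// C(i,j) += A(i,k) * B(k,j) with a sparse output C of rank 2, the
// isAdmissibleOrder check below requires at least rank(C) - 1 = 1 parallel
// loop before the first reduction, so the order (k, i, j) is rejected while
// (i, j, k) and (i, k, j) pass.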
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from most strict to least strict.
    // Ideally (though it might not be guaranteed), the earlier a constraint
    // mask can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Inverse `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      if (linalg::isReductionIterator(iteratorTypes[loopId]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }
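
  // A hedged illustration of the last-resort strategy implemented below
  // (encodings are made up): when every sort mask still yields a cycle, one
  // sparse input, say %a accessed through (d0, d1) -> (d1, d0), can be
  // materialized in a permuted encoding,
  //   %t = sparse_tensor.convert %a
  //        : tensor<?x?xf64, #X> to tensor<?x?xf64, #X_permuted>
  // so that its levels are visited in the scheduled loop order, removing its
  // conflicting constraints from the iteration graph.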
  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute the topological sort while leaving out every sparse input
    // tensor in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling loop without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order); // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example,
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // where loops are scheduled in the order d0->d1->d2->d3. To resolve
      // the cycle, we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2),
      // such that the transposed tensor's levels are visited in the same
      // order as the loop scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Inserting the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//
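
// A hedged sizing example for TensorAllocDemapper below (a 2x2 block
// encoding is assumed for exposition): for an allocation with dynamic dim
// sizes (%m, %n) and dimToLvl
//   (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2),
// the pattern forms the maximal coordinates (%m - 1, %n - 1), translates
// them dim2lvl, and re-derives the dynamic level sizes by adding 1 to the
// translated maximal level coordinates.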
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};
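
// A hedged walk-through of the block-argument rewiring performed below (it
// mirrors the bracketed comments inside rewriteOp): a foreach body starts as
// [dimCrds, val, initArgs], temporarily grows to
// [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs], and is trimmed to
// [lvlCrds, val, DemappedArgs] once the dim coordinates have been re-derived
// from the level coordinates through a lvl2dim translation.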
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity
    // dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
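
// A minimal usage sketch (hypothetical driver code, not part of this file),
// assuming an op `root` whose nested kernels should be rewritten greedily:
//   RewritePatternSet patterns(root->getContext());
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
//     /* handle failure */;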