//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Transforms/Passes.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" namespace mlir { #define GEN_PASS_DEF_SPARSESPACECOLLAPSE #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE "sparse-space-collapse" using namespace mlir; using namespace sparse_tensor; namespace { struct CollapseSpaceInfo { ExtractIterSpaceOp space; IterateOp loop; }; bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) { auto pIterArgs = parent.getRegionIterArgs(); auto nInitArgs = node.getInits(); if (pIterArgs.size() != nInitArgs.size()) return false; // Two loops are collapsable if they are perfectly nested. auto pYields = parent.getYieldedValues(); auto nResult = node.getLoopResults().value(); bool yieldEq = llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) { return std::get<0>(zipped) == std::get<1>(zipped); }); // Parent iter_args should be passed directly to the node's init_args. bool iterArgEq = llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) { return std::get<0>(zipped) == std::get<1>(zipped); }); return yieldEq && iterArgEq; } bool legalToCollapse(SmallVectorImpl &toCollapse, ExtractIterSpaceOp curSpace) { auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp { Value spaceVal = space.getExtractedSpace(); if (spaceVal.hasOneUse()) return llvm::dyn_cast(*spaceVal.getUsers().begin()); return nullptr; }; if (toCollapse.empty()) { // Collapse root. if (auto itOp = getIterateOpOverSpace(curSpace)) { CollapseSpaceInfo &info = toCollapse.emplace_back(); info.space = curSpace; info.loop = itOp; return true; } return false; } auto parent = toCollapse.back().space; auto pItOp = toCollapse.back().loop; auto nItOp = getIterateOpOverSpace(curSpace); // Can only collapse spaces extracted from the same tensor. if (parent.getTensor() != curSpace.getTensor()) { LLVM_DEBUG({ llvm::dbgs() << "failed to collpase spaces extracted from different tensors."; }); return false; } // Can only collapse consecutive simple iteration on one tensor (i.e., no // coiteration). if (!nItOp || nItOp->getBlock() != curSpace->getBlock() || pItOp.getIterator() != curSpace.getParentIter() || curSpace->getParentOp() != pItOp.getOperation()) { LLVM_DEBUG( { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; }); return false; } if (pItOp && !isCollapsableLoops(pItOp, nItOp)) { LLVM_DEBUG({ llvm::dbgs() << "failed to collapse IterateOps that are not perfectly nested."; }); return false; } CollapseSpaceInfo &info = toCollapse.emplace_back(); info.space = curSpace; info.loop = nItOp; return true; } void collapseSparseSpace(MutableArrayRef toCollapse) { if (toCollapse.size() < 2) return; ExtractIterSpaceOp root = toCollapse.front().space; ExtractIterSpaceOp leaf = toCollapse.back().space; Location loc = root.getLoc(); assert(root->hasOneUse() && leaf->hasOneUse()); // Insert collapsed operation at the same scope as root operation. OpBuilder builder(root); // Construct the collapsed iteration space. auto collapsedSpace = builder.create( loc, root.getTensor(), root.getParentIter(), root.getLoLvl(), leaf.getHiLvl()); auto rItOp = llvm::cast(*root->getUsers().begin()); auto innermost = toCollapse.back().loop; IRMapping mapper; mapper.map(leaf, collapsedSpace.getExtractedSpace()); for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs())) mapper.map(std::get<0>(z), std::get<1>(z)); auto cloned = llvm::cast(builder.clone(*innermost, mapper)); builder.setInsertionPointToStart(cloned.getBody()); I64BitSet crdUsedLvls; unsigned shift = 0, argIdx = 1; for (auto info : toCollapse.drop_back()) { I64BitSet set = info.loop.getCrdUsedLvls(); crdUsedLvls |= set.lshift(shift); shift += info.loop.getSpaceDim(); for (BlockArgument crd : info.loop.getCrds()) { BlockArgument collapsedCrd = cloned.getBody()->insertArgument( argIdx++, builder.getIndexType(), crd.getLoc()); crd.replaceAllUsesWith(collapsedCrd); } } crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift); cloned.getIterator().setType(collapsedSpace.getType().getIteratorType()); cloned.setCrdUsedLvls(crdUsedLvls); rItOp.replaceAllUsesWith(cloned.getResults()); // Erase collapsed loops. rItOp.erase(); root.erase(); } struct SparseSpaceCollapsePass : public impl::SparseSpaceCollapseBase { SparseSpaceCollapsePass() = default; void runOnOperation() override { func::FuncOp func = getOperation(); // A naive (experimental) implementation to collapse consecutive sparse // spaces. It does NOT handle complex cases where multiple spaces are // extracted in the same basic block. E.g., // // %space1 = extract_space %t1 ... // %space2 = extract_space %t2 ... // sparse_tensor.iterate(%sp1) ... // SmallVector toCollapse; func->walk([&](ExtractIterSpaceOp op) { if (!legalToCollapse(toCollapse, op)) { // if not legal to collapse one more space, collapse the existing ones // and clear. collapseSparseSpace(toCollapse); toCollapse.clear(); } }); collapseSparseSpace(toCollapse); } }; } // namespace std::unique_ptr mlir::createSparseSpaceCollapsePass() { return std::make_unique(); }