1c6d85bafSPeiming Liu //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===// 2c6d85bafSPeiming Liu // 3c6d85bafSPeiming Liu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c6d85bafSPeiming Liu // See https://llvm.org/LICENSE.txt for license information. 5c6d85bafSPeiming Liu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c6d85bafSPeiming Liu // 7c6d85bafSPeiming Liu //===----------------------------------------------------------------------===// 8c6d85bafSPeiming Liu 9c6d85bafSPeiming Liu #include "mlir/Dialect/Func/IR/FuncOps.h" 10c6d85bafSPeiming Liu #include "mlir/IR/IRMapping.h" 11c6d85bafSPeiming Liu #include "mlir/Transforms/Passes.h" 12c6d85bafSPeiming Liu 13c6d85bafSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14c6d85bafSPeiming Liu #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 15c6d85bafSPeiming Liu 16c6d85bafSPeiming Liu namespace mlir { 17c6d85bafSPeiming Liu #define GEN_PASS_DEF_SPARSESPACECOLLAPSE 18c6d85bafSPeiming Liu #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 19c6d85bafSPeiming Liu } // namespace mlir 20c6d85bafSPeiming Liu 21c6d85bafSPeiming Liu #define DEBUG_TYPE "sparse-space-collapse" 22c6d85bafSPeiming Liu 23c6d85bafSPeiming Liu using namespace mlir; 24c6d85bafSPeiming Liu using namespace sparse_tensor; 25c6d85bafSPeiming Liu 26c6d85bafSPeiming Liu namespace { 27c6d85bafSPeiming Liu 28c6d85bafSPeiming Liu struct CollapseSpaceInfo { 29c6d85bafSPeiming Liu ExtractIterSpaceOp space; 30c6d85bafSPeiming Liu IterateOp loop; 31c6d85bafSPeiming Liu }; 32c6d85bafSPeiming Liu 33c6d85bafSPeiming Liu bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) { 34c6d85bafSPeiming Liu auto pIterArgs = parent.getRegionIterArgs(); 35c6d85bafSPeiming Liu auto nInitArgs = node.getInits(); 36c6d85bafSPeiming Liu if (pIterArgs.size() != nInitArgs.size()) 37c6d85bafSPeiming Liu return false; 38c6d85bafSPeiming Liu 39c6d85bafSPeiming Liu // Two loops are collapsable if they are perfectly nested. 40c6d85bafSPeiming Liu auto pYields = parent.getYieldedValues(); 41c6d85bafSPeiming Liu auto nResult = node.getLoopResults().value(); 42c6d85bafSPeiming Liu 43c6d85bafSPeiming Liu bool yieldEq = 44c6d85bafSPeiming Liu llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) { 45c6d85bafSPeiming Liu return std::get<0>(zipped) == std::get<1>(zipped); 46c6d85bafSPeiming Liu }); 47c6d85bafSPeiming Liu 48c6d85bafSPeiming Liu // Parent iter_args should be passed directly to the node's init_args. 49c6d85bafSPeiming Liu bool iterArgEq = 50c6d85bafSPeiming Liu llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) { 51c6d85bafSPeiming Liu return std::get<0>(zipped) == std::get<1>(zipped); 52c6d85bafSPeiming Liu }); 53c6d85bafSPeiming Liu 54c6d85bafSPeiming Liu return yieldEq && iterArgEq; 55c6d85bafSPeiming Liu } 56c6d85bafSPeiming Liu 57c6d85bafSPeiming Liu bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse, 58c6d85bafSPeiming Liu ExtractIterSpaceOp curSpace) { 59c6d85bafSPeiming Liu 60c6d85bafSPeiming Liu auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp { 61c6d85bafSPeiming Liu Value spaceVal = space.getExtractedSpace(); 62c6d85bafSPeiming Liu if (spaceVal.hasOneUse()) 63c6d85bafSPeiming Liu return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin()); 64c6d85bafSPeiming Liu return nullptr; 65c6d85bafSPeiming Liu }; 66c6d85bafSPeiming Liu 67c6d85bafSPeiming Liu if (toCollapse.empty()) { 68c6d85bafSPeiming Liu // Collapse root. 69c6d85bafSPeiming Liu if (auto itOp = getIterateOpOverSpace(curSpace)) { 70c6d85bafSPeiming Liu CollapseSpaceInfo &info = toCollapse.emplace_back(); 71c6d85bafSPeiming Liu info.space = curSpace; 72c6d85bafSPeiming Liu info.loop = itOp; 73c6d85bafSPeiming Liu return true; 74c6d85bafSPeiming Liu } 75c6d85bafSPeiming Liu return false; 76c6d85bafSPeiming Liu } 77c6d85bafSPeiming Liu 78c6d85bafSPeiming Liu auto parent = toCollapse.back().space; 79c6d85bafSPeiming Liu auto pItOp = toCollapse.back().loop; 80c6d85bafSPeiming Liu auto nItOp = getIterateOpOverSpace(curSpace); 81c6d85bafSPeiming Liu 82c6d85bafSPeiming Liu // Can only collapse spaces extracted from the same tensor. 83c6d85bafSPeiming Liu if (parent.getTensor() != curSpace.getTensor()) { 84c6d85bafSPeiming Liu LLVM_DEBUG({ 85c6d85bafSPeiming Liu llvm::dbgs() 86c6d85bafSPeiming Liu << "failed to collpase spaces extracted from different tensors."; 87c6d85bafSPeiming Liu }); 88c6d85bafSPeiming Liu return false; 89c6d85bafSPeiming Liu } 90c6d85bafSPeiming Liu 91c6d85bafSPeiming Liu // Can only collapse consecutive simple iteration on one tensor (i.e., no 92c6d85bafSPeiming Liu // coiteration). 93c6d85bafSPeiming Liu if (!nItOp || nItOp->getBlock() != curSpace->getBlock() || 94c6d85bafSPeiming Liu pItOp.getIterator() != curSpace.getParentIter() || 95c6d85bafSPeiming Liu curSpace->getParentOp() != pItOp.getOperation()) { 96c6d85bafSPeiming Liu LLVM_DEBUG( 97c6d85bafSPeiming Liu { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; }); 98c6d85bafSPeiming Liu return false; 99c6d85bafSPeiming Liu } 100c6d85bafSPeiming Liu 101c6d85bafSPeiming Liu if (pItOp && !isCollapsableLoops(pItOp, nItOp)) { 102c6d85bafSPeiming Liu LLVM_DEBUG({ 103c6d85bafSPeiming Liu llvm::dbgs() 104c6d85bafSPeiming Liu << "failed to collapse IterateOps that are not perfectly nested."; 105c6d85bafSPeiming Liu }); 106c6d85bafSPeiming Liu return false; 107c6d85bafSPeiming Liu } 108c6d85bafSPeiming Liu 109c6d85bafSPeiming Liu CollapseSpaceInfo &info = toCollapse.emplace_back(); 110c6d85bafSPeiming Liu info.space = curSpace; 111c6d85bafSPeiming Liu info.loop = nItOp; 112c6d85bafSPeiming Liu return true; 113c6d85bafSPeiming Liu } 114c6d85bafSPeiming Liu 115c6d85bafSPeiming Liu void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) { 116c6d85bafSPeiming Liu if (toCollapse.size() < 2) 117c6d85bafSPeiming Liu return; 118c6d85bafSPeiming Liu 119c6d85bafSPeiming Liu ExtractIterSpaceOp root = toCollapse.front().space; 120c6d85bafSPeiming Liu ExtractIterSpaceOp leaf = toCollapse.back().space; 121c6d85bafSPeiming Liu Location loc = root.getLoc(); 122c6d85bafSPeiming Liu 123c6d85bafSPeiming Liu assert(root->hasOneUse() && leaf->hasOneUse()); 124c6d85bafSPeiming Liu 125c6d85bafSPeiming Liu // Insert collapsed operation at the same scope as root operation. 126c6d85bafSPeiming Liu OpBuilder builder(root); 127c6d85bafSPeiming Liu 128c6d85bafSPeiming Liu // Construct the collapsed iteration space. 129c6d85bafSPeiming Liu auto collapsedSpace = builder.create<ExtractIterSpaceOp>( 130c6d85bafSPeiming Liu loc, root.getTensor(), root.getParentIter(), root.getLoLvl(), 131c6d85bafSPeiming Liu leaf.getHiLvl()); 132c6d85bafSPeiming Liu 133c6d85bafSPeiming Liu auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin()); 134c6d85bafSPeiming Liu auto innermost = toCollapse.back().loop; 135c6d85bafSPeiming Liu 136c6d85bafSPeiming Liu IRMapping mapper; 137c6d85bafSPeiming Liu mapper.map(leaf, collapsedSpace.getExtractedSpace()); 138c6d85bafSPeiming Liu for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs())) 139c6d85bafSPeiming Liu mapper.map(std::get<0>(z), std::get<1>(z)); 140c6d85bafSPeiming Liu 141c6d85bafSPeiming Liu auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper)); 142c6d85bafSPeiming Liu builder.setInsertionPointToStart(cloned.getBody()); 143c6d85bafSPeiming Liu 144*785a24f1SPeiming Liu I64BitSet crdUsedLvls; 145c6d85bafSPeiming Liu unsigned shift = 0, argIdx = 1; 146c6d85bafSPeiming Liu for (auto info : toCollapse.drop_back()) { 147*785a24f1SPeiming Liu I64BitSet set = info.loop.getCrdUsedLvls(); 148c6d85bafSPeiming Liu crdUsedLvls |= set.lshift(shift); 149c6d85bafSPeiming Liu shift += info.loop.getSpaceDim(); 150c6d85bafSPeiming Liu for (BlockArgument crd : info.loop.getCrds()) { 151c6d85bafSPeiming Liu BlockArgument collapsedCrd = cloned.getBody()->insertArgument( 152c6d85bafSPeiming Liu argIdx++, builder.getIndexType(), crd.getLoc()); 153c6d85bafSPeiming Liu crd.replaceAllUsesWith(collapsedCrd); 154c6d85bafSPeiming Liu } 155c6d85bafSPeiming Liu } 156c6d85bafSPeiming Liu crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift); 157c6d85bafSPeiming Liu cloned.getIterator().setType(collapsedSpace.getType().getIteratorType()); 158c6d85bafSPeiming Liu cloned.setCrdUsedLvls(crdUsedLvls); 159c6d85bafSPeiming Liu 160c6d85bafSPeiming Liu rItOp.replaceAllUsesWith(cloned.getResults()); 161c6d85bafSPeiming Liu // Erase collapsed loops. 162c6d85bafSPeiming Liu rItOp.erase(); 163c6d85bafSPeiming Liu root.erase(); 164c6d85bafSPeiming Liu } 165c6d85bafSPeiming Liu 166c6d85bafSPeiming Liu struct SparseSpaceCollapsePass 167c6d85bafSPeiming Liu : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> { 168c6d85bafSPeiming Liu SparseSpaceCollapsePass() = default; 169c6d85bafSPeiming Liu 170c6d85bafSPeiming Liu void runOnOperation() override { 171c6d85bafSPeiming Liu func::FuncOp func = getOperation(); 172c6d85bafSPeiming Liu 173c6d85bafSPeiming Liu // A naive (experimental) implementation to collapse consecutive sparse 174c6d85bafSPeiming Liu // spaces. It does NOT handle complex cases where multiple spaces are 175c6d85bafSPeiming Liu // extracted in the same basic block. E.g., 176c6d85bafSPeiming Liu // 177c6d85bafSPeiming Liu // %space1 = extract_space %t1 ... 178c6d85bafSPeiming Liu // %space2 = extract_space %t2 ... 179c6d85bafSPeiming Liu // sparse_tensor.iterate(%sp1) ... 180c6d85bafSPeiming Liu // 181c6d85bafSPeiming Liu SmallVector<CollapseSpaceInfo> toCollapse; 182c6d85bafSPeiming Liu func->walk([&](ExtractIterSpaceOp op) { 183c6d85bafSPeiming Liu if (!legalToCollapse(toCollapse, op)) { 184c6d85bafSPeiming Liu // if not legal to collapse one more space, collapse the existing ones 185c6d85bafSPeiming Liu // and clear. 186c6d85bafSPeiming Liu collapseSparseSpace(toCollapse); 187c6d85bafSPeiming Liu toCollapse.clear(); 188c6d85bafSPeiming Liu } 189c6d85bafSPeiming Liu }); 190c6d85bafSPeiming Liu 191c6d85bafSPeiming Liu collapseSparseSpace(toCollapse); 192c6d85bafSPeiming Liu } 193c6d85bafSPeiming Liu }; 194c6d85bafSPeiming Liu 195c6d85bafSPeiming Liu } // namespace 196c6d85bafSPeiming Liu 197c6d85bafSPeiming Liu std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() { 198c6d85bafSPeiming Liu return std::make_unique<SparseSpaceCollapsePass>(); 199c6d85bafSPeiming Liu } 200