1 //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 #include "mlir/IR/IRMapping.h" 11 #include "mlir/Transforms/Passes.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 15 16 namespace mlir { 17 #define GEN_PASS_DEF_SPARSESPACECOLLAPSE 18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 19 } // namespace mlir 20 21 #define DEBUG_TYPE "sparse-space-collapse" 22 23 using namespace mlir; 24 using namespace sparse_tensor; 25 26 namespace { 27 28 struct CollapseSpaceInfo { 29 ExtractIterSpaceOp space; 30 IterateOp loop; 31 }; 32 33 bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) { 34 auto pIterArgs = parent.getRegionIterArgs(); 35 auto nInitArgs = node.getInits(); 36 if (pIterArgs.size() != nInitArgs.size()) 37 return false; 38 39 // Two loops are collapsable if they are perfectly nested. 40 auto pYields = parent.getYieldedValues(); 41 auto nResult = node.getLoopResults().value(); 42 43 bool yieldEq = 44 llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) { 45 return std::get<0>(zipped) == std::get<1>(zipped); 46 }); 47 48 // Parent iter_args should be passed directly to the node's init_args. 49 bool iterArgEq = 50 llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) { 51 return std::get<0>(zipped) == std::get<1>(zipped); 52 }); 53 54 return yieldEq && iterArgEq; 55 } 56 57 bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse, 58 ExtractIterSpaceOp curSpace) { 59 60 auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp { 61 Value spaceVal = space.getExtractedSpace(); 62 if (spaceVal.hasOneUse()) 63 return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin()); 64 return nullptr; 65 }; 66 67 if (toCollapse.empty()) { 68 // Collapse root. 69 if (auto itOp = getIterateOpOverSpace(curSpace)) { 70 CollapseSpaceInfo &info = toCollapse.emplace_back(); 71 info.space = curSpace; 72 info.loop = itOp; 73 return true; 74 } 75 return false; 76 } 77 78 auto parent = toCollapse.back().space; 79 auto pItOp = toCollapse.back().loop; 80 auto nItOp = getIterateOpOverSpace(curSpace); 81 82 // Can only collapse spaces extracted from the same tensor. 83 if (parent.getTensor() != curSpace.getTensor()) { 84 LLVM_DEBUG({ 85 llvm::dbgs() 86 << "failed to collpase spaces extracted from different tensors."; 87 }); 88 return false; 89 } 90 91 // Can only collapse consecutive simple iteration on one tensor (i.e., no 92 // coiteration). 93 if (!nItOp || nItOp->getBlock() != curSpace->getBlock() || 94 pItOp.getIterator() != curSpace.getParentIter() || 95 curSpace->getParentOp() != pItOp.getOperation()) { 96 LLVM_DEBUG( 97 { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; }); 98 return false; 99 } 100 101 if (pItOp && !isCollapsableLoops(pItOp, nItOp)) { 102 LLVM_DEBUG({ 103 llvm::dbgs() 104 << "failed to collapse IterateOps that are not perfectly nested."; 105 }); 106 return false; 107 } 108 109 CollapseSpaceInfo &info = toCollapse.emplace_back(); 110 info.space = curSpace; 111 info.loop = nItOp; 112 return true; 113 } 114 115 void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) { 116 if (toCollapse.size() < 2) 117 return; 118 119 ExtractIterSpaceOp root = toCollapse.front().space; 120 ExtractIterSpaceOp leaf = toCollapse.back().space; 121 Location loc = root.getLoc(); 122 123 assert(root->hasOneUse() && leaf->hasOneUse()); 124 125 // Insert collapsed operation at the same scope as root operation. 126 OpBuilder builder(root); 127 128 // Construct the collapsed iteration space. 129 auto collapsedSpace = builder.create<ExtractIterSpaceOp>( 130 loc, root.getTensor(), root.getParentIter(), root.getLoLvl(), 131 leaf.getHiLvl()); 132 133 auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin()); 134 auto innermost = toCollapse.back().loop; 135 136 IRMapping mapper; 137 mapper.map(leaf, collapsedSpace.getExtractedSpace()); 138 for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs())) 139 mapper.map(std::get<0>(z), std::get<1>(z)); 140 141 auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper)); 142 builder.setInsertionPointToStart(cloned.getBody()); 143 144 I64BitSet crdUsedLvls; 145 unsigned shift = 0, argIdx = 1; 146 for (auto info : toCollapse.drop_back()) { 147 I64BitSet set = info.loop.getCrdUsedLvls(); 148 crdUsedLvls |= set.lshift(shift); 149 shift += info.loop.getSpaceDim(); 150 for (BlockArgument crd : info.loop.getCrds()) { 151 BlockArgument collapsedCrd = cloned.getBody()->insertArgument( 152 argIdx++, builder.getIndexType(), crd.getLoc()); 153 crd.replaceAllUsesWith(collapsedCrd); 154 } 155 } 156 crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift); 157 cloned.getIterator().setType(collapsedSpace.getType().getIteratorType()); 158 cloned.setCrdUsedLvls(crdUsedLvls); 159 160 rItOp.replaceAllUsesWith(cloned.getResults()); 161 // Erase collapsed loops. 162 rItOp.erase(); 163 root.erase(); 164 } 165 166 struct SparseSpaceCollapsePass 167 : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> { 168 SparseSpaceCollapsePass() = default; 169 170 void runOnOperation() override { 171 func::FuncOp func = getOperation(); 172 173 // A naive (experimental) implementation to collapse consecutive sparse 174 // spaces. It does NOT handle complex cases where multiple spaces are 175 // extracted in the same basic block. E.g., 176 // 177 // %space1 = extract_space %t1 ... 178 // %space2 = extract_space %t2 ... 179 // sparse_tensor.iterate(%sp1) ... 180 // 181 SmallVector<CollapseSpaceInfo> toCollapse; 182 func->walk([&](ExtractIterSpaceOp op) { 183 if (!legalToCollapse(toCollapse, op)) { 184 // if not legal to collapse one more space, collapse the existing ones 185 // and clear. 186 collapseSparseSpace(toCollapse); 187 toCollapse.clear(); 188 } 189 }); 190 191 collapseSparseSpace(toCollapse); 192 } 193 }; 194 195 } // namespace 196 197 std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() { 198 return std::make_unique<SparseSpaceCollapsePass>(); 199 } 200