xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp (revision 785a24f1561c610ecbce7cdfbff053e0a3a7caec)
1 //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/IR/IRMapping.h"
11 #include "mlir/Transforms/Passes.h"
12 
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_SPARSESPACECOLLAPSE
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
19 } // namespace mlir
20 
21 #define DEBUG_TYPE "sparse-space-collapse"
22 
23 using namespace mlir;
24 using namespace sparse_tensor;
25 
26 namespace {
27 
/// One link in a chain of iteration spaces that are candidates for collapsing:
/// a space-extraction op paired with the loop recorded as iterating over it.
struct CollapseSpaceInfo {
  // The op that extracts the iteration space.
  ExtractIterSpaceOp space;
  // The IterateOp recorded for `space` (its single user when the entry is
  // created by legalToCollapse).
  IterateOp loop;
};
32 
33 bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
34   auto pIterArgs = parent.getRegionIterArgs();
35   auto nInitArgs = node.getInits();
36   if (pIterArgs.size() != nInitArgs.size())
37     return false;
38 
39   // Two loops are collapsable if they are perfectly nested.
40   auto pYields = parent.getYieldedValues();
41   auto nResult = node.getLoopResults().value();
42 
43   bool yieldEq =
44       llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
45         return std::get<0>(zipped) == std::get<1>(zipped);
46       });
47 
48   // Parent iter_args should be passed directly to the node's init_args.
49   bool iterArgEq =
50       llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
51         return std::get<0>(zipped) == std::get<1>(zipped);
52       });
53 
54   return yieldEq && iterArgEq;
55 }
56 
57 bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
58                      ExtractIterSpaceOp curSpace) {
59 
60   auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
61     Value spaceVal = space.getExtractedSpace();
62     if (spaceVal.hasOneUse())
63       return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
64     return nullptr;
65   };
66 
67   if (toCollapse.empty()) {
68     // Collapse root.
69     if (auto itOp = getIterateOpOverSpace(curSpace)) {
70       CollapseSpaceInfo &info = toCollapse.emplace_back();
71       info.space = curSpace;
72       info.loop = itOp;
73       return true;
74     }
75     return false;
76   }
77 
78   auto parent = toCollapse.back().space;
79   auto pItOp = toCollapse.back().loop;
80   auto nItOp = getIterateOpOverSpace(curSpace);
81 
82   // Can only collapse spaces extracted from the same tensor.
83   if (parent.getTensor() != curSpace.getTensor()) {
84     LLVM_DEBUG({
85       llvm::dbgs()
86           << "failed to collpase spaces extracted from different tensors.";
87     });
88     return false;
89   }
90 
91   // Can only collapse consecutive simple iteration on one tensor (i.e., no
92   // coiteration).
93   if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
94       pItOp.getIterator() != curSpace.getParentIter() ||
95       curSpace->getParentOp() != pItOp.getOperation()) {
96     LLVM_DEBUG(
97         { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
98     return false;
99   }
100 
101   if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
102     LLVM_DEBUG({
103       llvm::dbgs()
104           << "failed to collapse IterateOps that are not perfectly nested.";
105     });
106     return false;
107   }
108 
109   CollapseSpaceInfo &info = toCollapse.emplace_back();
110   info.space = curSpace;
111   info.loop = nItOp;
112   return true;
113 }
114 
115 void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
116   if (toCollapse.size() < 2)
117     return;
118 
119   ExtractIterSpaceOp root = toCollapse.front().space;
120   ExtractIterSpaceOp leaf = toCollapse.back().space;
121   Location loc = root.getLoc();
122 
123   assert(root->hasOneUse() && leaf->hasOneUse());
124 
125   // Insert collapsed operation at the same scope as root operation.
126   OpBuilder builder(root);
127 
128   // Construct the collapsed iteration space.
129   auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
130       loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
131       leaf.getHiLvl());
132 
133   auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
134   auto innermost = toCollapse.back().loop;
135 
136   IRMapping mapper;
137   mapper.map(leaf, collapsedSpace.getExtractedSpace());
138   for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
139     mapper.map(std::get<0>(z), std::get<1>(z));
140 
141   auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
142   builder.setInsertionPointToStart(cloned.getBody());
143 
144   I64BitSet crdUsedLvls;
145   unsigned shift = 0, argIdx = 1;
146   for (auto info : toCollapse.drop_back()) {
147     I64BitSet set = info.loop.getCrdUsedLvls();
148     crdUsedLvls |= set.lshift(shift);
149     shift += info.loop.getSpaceDim();
150     for (BlockArgument crd : info.loop.getCrds()) {
151       BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
152           argIdx++, builder.getIndexType(), crd.getLoc());
153       crd.replaceAllUsesWith(collapsedCrd);
154     }
155   }
156   crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
157   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
158   cloned.setCrdUsedLvls(crdUsedLvls);
159 
160   rItOp.replaceAllUsesWith(cloned.getResults());
161   // Erase collapsed loops.
162   rItOp.erase();
163   root.erase();
164 }
165 
166 struct SparseSpaceCollapsePass
167     : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
168   SparseSpaceCollapsePass() = default;
169 
170   void runOnOperation() override {
171     func::FuncOp func = getOperation();
172 
173     // A naive (experimental) implementation to collapse consecutive sparse
174     // spaces. It does NOT handle complex cases where multiple spaces are
175     // extracted in the same basic block. E.g.,
176     //
177     // %space1 = extract_space %t1 ...
178     // %space2 = extract_space %t2 ...
179     // sparse_tensor.iterate(%sp1) ...
180     //
181     SmallVector<CollapseSpaceInfo> toCollapse;
182     func->walk([&](ExtractIterSpaceOp op) {
183       if (!legalToCollapse(toCollapse, op)) {
184         // if not legal to collapse one more space, collapse the existing ones
185         // and clear.
186         collapseSparseSpace(toCollapse);
187         toCollapse.clear();
188       }
189     });
190 
191     collapseSparseSpace(toCollapse);
192   }
193 };
194 
195 } // namespace
196 
/// Factory for the sparse space collapse pass: returns a newly allocated
/// SparseSpaceCollapsePass owned by the caller.
std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
  return std::make_unique<SparseSpaceCollapsePass>();
}
200