xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp (revision 785a24f1561c610ecbce7cdfbff053e0a3a7caec)
1c6d85bafSPeiming Liu //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
2c6d85bafSPeiming Liu //
3c6d85bafSPeiming Liu // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c6d85bafSPeiming Liu // See https://llvm.org/LICENSE.txt for license information.
5c6d85bafSPeiming Liu // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c6d85bafSPeiming Liu //
7c6d85bafSPeiming Liu //===----------------------------------------------------------------------===//
8c6d85bafSPeiming Liu 
9c6d85bafSPeiming Liu #include "mlir/Dialect/Func/IR/FuncOps.h"
10c6d85bafSPeiming Liu #include "mlir/IR/IRMapping.h"
11c6d85bafSPeiming Liu #include "mlir/Transforms/Passes.h"
12c6d85bafSPeiming Liu 
13c6d85bafSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14c6d85bafSPeiming Liu #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
15c6d85bafSPeiming Liu 
16c6d85bafSPeiming Liu namespace mlir {
17c6d85bafSPeiming Liu #define GEN_PASS_DEF_SPARSESPACECOLLAPSE
18c6d85bafSPeiming Liu #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
19c6d85bafSPeiming Liu } // namespace mlir
20c6d85bafSPeiming Liu 
21c6d85bafSPeiming Liu #define DEBUG_TYPE "sparse-space-collapse"
22c6d85bafSPeiming Liu 
23c6d85bafSPeiming Liu using namespace mlir;
24c6d85bafSPeiming Liu using namespace sparse_tensor;
25c6d85bafSPeiming Liu 
26c6d85bafSPeiming Liu namespace {
27c6d85bafSPeiming Liu 
28c6d85bafSPeiming Liu struct CollapseSpaceInfo {
29c6d85bafSPeiming Liu   ExtractIterSpaceOp space;
30c6d85bafSPeiming Liu   IterateOp loop;
31c6d85bafSPeiming Liu };
32c6d85bafSPeiming Liu 
33c6d85bafSPeiming Liu bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
34c6d85bafSPeiming Liu   auto pIterArgs = parent.getRegionIterArgs();
35c6d85bafSPeiming Liu   auto nInitArgs = node.getInits();
36c6d85bafSPeiming Liu   if (pIterArgs.size() != nInitArgs.size())
37c6d85bafSPeiming Liu     return false;
38c6d85bafSPeiming Liu 
39c6d85bafSPeiming Liu   // Two loops are collapsable if they are perfectly nested.
40c6d85bafSPeiming Liu   auto pYields = parent.getYieldedValues();
41c6d85bafSPeiming Liu   auto nResult = node.getLoopResults().value();
42c6d85bafSPeiming Liu 
43c6d85bafSPeiming Liu   bool yieldEq =
44c6d85bafSPeiming Liu       llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
45c6d85bafSPeiming Liu         return std::get<0>(zipped) == std::get<1>(zipped);
46c6d85bafSPeiming Liu       });
47c6d85bafSPeiming Liu 
48c6d85bafSPeiming Liu   // Parent iter_args should be passed directly to the node's init_args.
49c6d85bafSPeiming Liu   bool iterArgEq =
50c6d85bafSPeiming Liu       llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
51c6d85bafSPeiming Liu         return std::get<0>(zipped) == std::get<1>(zipped);
52c6d85bafSPeiming Liu       });
53c6d85bafSPeiming Liu 
54c6d85bafSPeiming Liu   return yieldEq && iterArgEq;
55c6d85bafSPeiming Liu }
56c6d85bafSPeiming Liu 
57c6d85bafSPeiming Liu bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
58c6d85bafSPeiming Liu                      ExtractIterSpaceOp curSpace) {
59c6d85bafSPeiming Liu 
60c6d85bafSPeiming Liu   auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
61c6d85bafSPeiming Liu     Value spaceVal = space.getExtractedSpace();
62c6d85bafSPeiming Liu     if (spaceVal.hasOneUse())
63c6d85bafSPeiming Liu       return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
64c6d85bafSPeiming Liu     return nullptr;
65c6d85bafSPeiming Liu   };
66c6d85bafSPeiming Liu 
67c6d85bafSPeiming Liu   if (toCollapse.empty()) {
68c6d85bafSPeiming Liu     // Collapse root.
69c6d85bafSPeiming Liu     if (auto itOp = getIterateOpOverSpace(curSpace)) {
70c6d85bafSPeiming Liu       CollapseSpaceInfo &info = toCollapse.emplace_back();
71c6d85bafSPeiming Liu       info.space = curSpace;
72c6d85bafSPeiming Liu       info.loop = itOp;
73c6d85bafSPeiming Liu       return true;
74c6d85bafSPeiming Liu     }
75c6d85bafSPeiming Liu     return false;
76c6d85bafSPeiming Liu   }
77c6d85bafSPeiming Liu 
78c6d85bafSPeiming Liu   auto parent = toCollapse.back().space;
79c6d85bafSPeiming Liu   auto pItOp = toCollapse.back().loop;
80c6d85bafSPeiming Liu   auto nItOp = getIterateOpOverSpace(curSpace);
81c6d85bafSPeiming Liu 
82c6d85bafSPeiming Liu   // Can only collapse spaces extracted from the same tensor.
83c6d85bafSPeiming Liu   if (parent.getTensor() != curSpace.getTensor()) {
84c6d85bafSPeiming Liu     LLVM_DEBUG({
85c6d85bafSPeiming Liu       llvm::dbgs()
86c6d85bafSPeiming Liu           << "failed to collpase spaces extracted from different tensors.";
87c6d85bafSPeiming Liu     });
88c6d85bafSPeiming Liu     return false;
89c6d85bafSPeiming Liu   }
90c6d85bafSPeiming Liu 
91c6d85bafSPeiming Liu   // Can only collapse consecutive simple iteration on one tensor (i.e., no
92c6d85bafSPeiming Liu   // coiteration).
93c6d85bafSPeiming Liu   if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
94c6d85bafSPeiming Liu       pItOp.getIterator() != curSpace.getParentIter() ||
95c6d85bafSPeiming Liu       curSpace->getParentOp() != pItOp.getOperation()) {
96c6d85bafSPeiming Liu     LLVM_DEBUG(
97c6d85bafSPeiming Liu         { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
98c6d85bafSPeiming Liu     return false;
99c6d85bafSPeiming Liu   }
100c6d85bafSPeiming Liu 
101c6d85bafSPeiming Liu   if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
102c6d85bafSPeiming Liu     LLVM_DEBUG({
103c6d85bafSPeiming Liu       llvm::dbgs()
104c6d85bafSPeiming Liu           << "failed to collapse IterateOps that are not perfectly nested.";
105c6d85bafSPeiming Liu     });
106c6d85bafSPeiming Liu     return false;
107c6d85bafSPeiming Liu   }
108c6d85bafSPeiming Liu 
109c6d85bafSPeiming Liu   CollapseSpaceInfo &info = toCollapse.emplace_back();
110c6d85bafSPeiming Liu   info.space = curSpace;
111c6d85bafSPeiming Liu   info.loop = nItOp;
112c6d85bafSPeiming Liu   return true;
113c6d85bafSPeiming Liu }
114c6d85bafSPeiming Liu 
115c6d85bafSPeiming Liu void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
116c6d85bafSPeiming Liu   if (toCollapse.size() < 2)
117c6d85bafSPeiming Liu     return;
118c6d85bafSPeiming Liu 
119c6d85bafSPeiming Liu   ExtractIterSpaceOp root = toCollapse.front().space;
120c6d85bafSPeiming Liu   ExtractIterSpaceOp leaf = toCollapse.back().space;
121c6d85bafSPeiming Liu   Location loc = root.getLoc();
122c6d85bafSPeiming Liu 
123c6d85bafSPeiming Liu   assert(root->hasOneUse() && leaf->hasOneUse());
124c6d85bafSPeiming Liu 
125c6d85bafSPeiming Liu   // Insert collapsed operation at the same scope as root operation.
126c6d85bafSPeiming Liu   OpBuilder builder(root);
127c6d85bafSPeiming Liu 
128c6d85bafSPeiming Liu   // Construct the collapsed iteration space.
129c6d85bafSPeiming Liu   auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
130c6d85bafSPeiming Liu       loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
131c6d85bafSPeiming Liu       leaf.getHiLvl());
132c6d85bafSPeiming Liu 
133c6d85bafSPeiming Liu   auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
134c6d85bafSPeiming Liu   auto innermost = toCollapse.back().loop;
135c6d85bafSPeiming Liu 
136c6d85bafSPeiming Liu   IRMapping mapper;
137c6d85bafSPeiming Liu   mapper.map(leaf, collapsedSpace.getExtractedSpace());
138c6d85bafSPeiming Liu   for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
139c6d85bafSPeiming Liu     mapper.map(std::get<0>(z), std::get<1>(z));
140c6d85bafSPeiming Liu 
141c6d85bafSPeiming Liu   auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
142c6d85bafSPeiming Liu   builder.setInsertionPointToStart(cloned.getBody());
143c6d85bafSPeiming Liu 
144*785a24f1SPeiming Liu   I64BitSet crdUsedLvls;
145c6d85bafSPeiming Liu   unsigned shift = 0, argIdx = 1;
146c6d85bafSPeiming Liu   for (auto info : toCollapse.drop_back()) {
147*785a24f1SPeiming Liu     I64BitSet set = info.loop.getCrdUsedLvls();
148c6d85bafSPeiming Liu     crdUsedLvls |= set.lshift(shift);
149c6d85bafSPeiming Liu     shift += info.loop.getSpaceDim();
150c6d85bafSPeiming Liu     for (BlockArgument crd : info.loop.getCrds()) {
151c6d85bafSPeiming Liu       BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
152c6d85bafSPeiming Liu           argIdx++, builder.getIndexType(), crd.getLoc());
153c6d85bafSPeiming Liu       crd.replaceAllUsesWith(collapsedCrd);
154c6d85bafSPeiming Liu     }
155c6d85bafSPeiming Liu   }
156c6d85bafSPeiming Liu   crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
157c6d85bafSPeiming Liu   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
158c6d85bafSPeiming Liu   cloned.setCrdUsedLvls(crdUsedLvls);
159c6d85bafSPeiming Liu 
160c6d85bafSPeiming Liu   rItOp.replaceAllUsesWith(cloned.getResults());
161c6d85bafSPeiming Liu   // Erase collapsed loops.
162c6d85bafSPeiming Liu   rItOp.erase();
163c6d85bafSPeiming Liu   root.erase();
164c6d85bafSPeiming Liu }
165c6d85bafSPeiming Liu 
166c6d85bafSPeiming Liu struct SparseSpaceCollapsePass
167c6d85bafSPeiming Liu     : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
168c6d85bafSPeiming Liu   SparseSpaceCollapsePass() = default;
169c6d85bafSPeiming Liu 
170c6d85bafSPeiming Liu   void runOnOperation() override {
171c6d85bafSPeiming Liu     func::FuncOp func = getOperation();
172c6d85bafSPeiming Liu 
173c6d85bafSPeiming Liu     // A naive (experimental) implementation to collapse consecutive sparse
174c6d85bafSPeiming Liu     // spaces. It does NOT handle complex cases where multiple spaces are
175c6d85bafSPeiming Liu     // extracted in the same basic block. E.g.,
176c6d85bafSPeiming Liu     //
177c6d85bafSPeiming Liu     // %space1 = extract_space %t1 ...
178c6d85bafSPeiming Liu     // %space2 = extract_space %t2 ...
179c6d85bafSPeiming Liu     // sparse_tensor.iterate(%sp1) ...
180c6d85bafSPeiming Liu     //
181c6d85bafSPeiming Liu     SmallVector<CollapseSpaceInfo> toCollapse;
182c6d85bafSPeiming Liu     func->walk([&](ExtractIterSpaceOp op) {
183c6d85bafSPeiming Liu       if (!legalToCollapse(toCollapse, op)) {
184c6d85bafSPeiming Liu         // if not legal to collapse one more space, collapse the existing ones
185c6d85bafSPeiming Liu         // and clear.
186c6d85bafSPeiming Liu         collapseSparseSpace(toCollapse);
187c6d85bafSPeiming Liu         toCollapse.clear();
188c6d85bafSPeiming Liu       }
189c6d85bafSPeiming Liu     });
190c6d85bafSPeiming Liu 
191c6d85bafSPeiming Liu     collapseSparseSpace(toCollapse);
192c6d85bafSPeiming Liu   }
193c6d85bafSPeiming Liu };
194c6d85bafSPeiming Liu 
195c6d85bafSPeiming Liu } // namespace
196c6d85bafSPeiming Liu 
197c6d85bafSPeiming Liu std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
198c6d85bafSPeiming Liu   return std::make_unique<SparseSpaceCollapsePass>();
199c6d85bafSPeiming Liu }
200