//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// TODO: reuse StorageLayout::foreachField?

// TODO: we need to support COO, in both AoS and SoA form

// Convert a type range to a new type range, with sparse tensors externalized.
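//
// For example (an illustrative sketch, assuming a 2-d CSR-like type with a
// dense outer level and a compressed inner level, and the default index
// position/coordinate types), a type tensor<?x?xf64, #CSR> expands into
//
//   tensor<?xf64>   (values)
//   tensor<?xindex> (positions of the compressed level)
//   tensor<?xindex> (coordinates of the compressed level)
//
// Dense levels contribute no extra arrays.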
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes = nullptr) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }
    // Convert the external representation of the values array.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    auto shape = {ShapedType::kDynamic};
    auto vtp = RankedTensorType::get(shape, stt.getElementType());
    convTypes.push_back(vtp);
    if (extraTypes)
      extraTypes->push_back(vtp);
    // Convert the external representations of the pos/crd arrays.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        auto ptp = RankedTensorType::get(shape, stt.getPosType());
        auto ctp = RankedTensorType::get(shape, stt.getCrdType());
        convTypes.push_back(ptp);
        convTypes.push_back(ctp);
        if (extraTypes) {
          extraTypes->push_back(ptp);
          extraTypes->push_back(ctp);
        }
      } else {
        assert(isDenseLT(lt)); // TODO: handle other cases
      }
    }
  }
}

// Convert input and output values to [dis]assemble ops for sparse tensors.
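//
// On the input side (isIn == true), the constituent arrays are taken from
// fromVals and assembled into a single sparse tensor. On the output side,
// fromVals holds the sparse results of the inner call, extraVals supplies
// the caller-provided buffers (starting at index extra) into which each
// result is disassembled, and the disassembled arrays are returned.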
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra,
                     bool isIn) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Convert the external representation of the values array.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    auto shape = {ShapedType::kDynamic};
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    // Collect the external representation of the values array for
    // input or the outgoing sparse tensor for output.
    inputs.push_back(fromVals[idx++]);
    if (!isIn) {
      inputs.push_back(extraVals[extra++]);
      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
      cntTypes.push_back(builder.getIndexType());
    }
    // Collect the external representations of the pos/crd arrays.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
          inputs.push_back(fromVals[idx++]);
        } else {
          Type pTp = stt.getPosType();
          Type cTp = stt.getCrdType();
          inputs.push_back(extraVals[extra++]);
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(RankedTensorType::get(shape, pTp));
          retTypes.push_back(RankedTensorType::get(shape, cTp));
          cntTypes.push_back(pTp);
          cntTypes.push_back(cTp);
        }
      } else {
        assert(isDenseLT(lt)); // TODO: handle other cases
      }
    }
    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
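      // (Each counter reports how many entries were actually stored in the
      // corresponding array; only the array results themselves are forwarded
      // to the caller.)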
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// TODO: refine output sparse tensors to work well with external frameworks
//
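// For example (an illustrative sketch, assuming a sparse vector type #SV with
// a single compressed level and the default index position/coordinate types),
// a public entry method
//
//   func.func @foo(%t: tensor<?xf64, #SV>)
//
// becomes a private @_internal_foo with the original body, plus a public
// wrapper
//
//   func.func @foo(%v: tensor<?xf64>, %p: tensor<?xindex>, %c: tensor<?xindex>)
//
// that assembles %v, %p, %c into a sparse tensor %t and then calls
// @_internal_foo(%t).
//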
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);

    // Only sparse inputs or outputs need a wrapper method (convTypes changes
    // the number of types only when it expands a sparse tensor, so unchanged
    // sizes mean no sparse tensors are involved).
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
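    // Buffers for sparse results are passed as trailing arguments of the
    // wrapper; 'extra' records the argument index at which they start.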
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

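// A minimal usage sketch (hypothetical driver code; the actual pass wiring
// lives elsewhere in the SparseTensor transforms):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseAssembler(patterns);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
//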
void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
  patterns.add<SparseFuncAssembler>(patterns.getContext());
}