//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert a type range to a new type range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes = nullptr) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }
    // Convert the external representation of the values array.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    auto shape = stt.getBatchLvlShape();
    shape.push_back(ShapedType::kDynamic);
    auto vtp = RankedTensorType::get(shape, stt.getElementType());
    convTypes.push_back(vtp);
    if (extraTypes)
      extraTypes->push_back(vtp);

    // Convert the external representations of the position/coordinate arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
                                               Type t, FieldIndex,
                                               SparseTensorFieldKind kind,
                                               Level, LevelType) {
      if (kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::PosMemRef) {
        ShapedType st = cast<ShapedType>(t);
        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
        convTypes.push_back(rtp);
        if (extraTypes)
          extraTypes->push_back(rtp);
      }
      return true;
    });
  }
}

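// An illustrative sketch (the #CSR alias and the default index bit-widths are
// assumptions, not defined in this file): for a sparse argument type such as
//
//   tensor<?x?xf64, #CSR>
//
// convTypes emits the values array first, followed by one ranked tensor per
// positions/coordinates array in the storage layout:
//
//   tensor<?xf64>     (values)
//   tensor<?xindex>   (positions, level 1)
//   tensor<?xindex>   (coordinates, level 1)
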
// Convert input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra,
                     bool isIn) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Convert the external representation of the values array.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    auto shape = stt.getBatchLvlShape();
    shape.push_back(ShapedType::kDynamic);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    // Collect the external representation of the values array for
    // input or the outgoing sparse tensor for output.
    inputs.push_back(fromVals[idx++]);
    if (!isIn) {
      inputs.push_back(extraVals[extra++]);
      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
      cntTypes.push_back(builder.getIndexType()); // nnz
    }

    // Collect the external representations of the pos/crd arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level, LevelType) {
      if (kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::PosMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else {
          ShapedType st = cast<ShapedType>(t);
          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(rtp.getElementType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

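// A sketch of the emitted ops (syntax abbreviated; #CSR is an assumed alias):
// on input, the individual arrays are reassembled into one sparse tensor,
// roughly
//
//   %t = sparse_tensor.assemble (%pos, %crd), %vals
//        : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>
//          to tensor<?x?xf64, #CSR>
//
// while on output, a disassemble writes through the caller-provided buffers
// and also returns the actual lengths of each array (e.g. nnz); only the data
// results are forwarded and the length counters are dropped.
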
//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
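// As a concrete sketch (types and the #CSR alias are illustrative only),
//
//   func.func @foo(%t: tensor<?x?xf64, #CSR>) -> f64
//
// becomes a private @_internal_foo plus the public wrapper
//
//   func.func @foo(%vals: tensor<?xf64>, %pos: tensor<?xindex>,
//                  %crd: tensor<?xindex>) -> f64 {
//     %t = sparse_tensor.assemble (%pos, %crd), %vals : ... to ...
//     %r = call @_internal_foo(%t) : (tensor<?x?xf64, #CSR>) -> f64
//     return %r : f64
//   }
//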
// TODO: refine output sparse tensors to work well with external frameworks
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);

    // Only sparse inputs or outputs need a wrapper method.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct the new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
  patterns.add<SparseFuncAssembler>(patterns.getContext());
}
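
// Example driver (a sketch, not part of this file; the pass class name and
// surrounding boilerplate are assumptions): the pattern is typically applied
// with the greedy rewrite driver, e.g.
//
//   void MySparseAssemblerPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateSparseAssembler(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }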