//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Converts a type range to a new type range in which sparse tensors are
// externalized into their constituent buffer types.
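//
// For example (illustrative only; #CSR stands for an assumed CSR-like
// encoding, not defined in this file), a 2-D type such as
//
//   tensor<?x?xf64, #CSR>
//
// conceptually expands into the types of its positions, coordinates, and
// values buffers, roughly tensor<?xindex>, tensor<?xindex>, and
// tensor<?xf64>, as enumerated by foreachFieldAndTypeInSparseTensor below.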
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes = nullptr) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }

    // Convert the external representations of the position, coordinate, and
    // value arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
                                               Type t, FieldIndex,
                                               SparseTensorFieldKind kind,
                                               Level, LevelType) {
      if (kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        ShapedType st = cast<ShapedType>(t);
        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
        convTypes.push_back(rtp);
        if (extraTypes)
          extraTypes->push_back(rtp);
      }
      return true;
    });
  }
}

// Converts input and output values to [dis]assemble ops for sparse tensors.
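//
// When `isIn` is true, the buffers in `fromVals` are assembled into sparse
// tensors; otherwise, the sparse tensors in `fromVals` are disassembled into
// buffers, with `extraVals` (indexed from `extra`) supplying the destination
// storage provided by the caller.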
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra,
                     bool isIn) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level, LevelType) {
      if (kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else {
          ShapedType st = cast<ShapedType>(t);
          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}
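
// For example, with the assumed CSR-like encoding #CSR from above, the
// assemble path would emit roughly
//
//   %t = sparse_tensor.assemble (%pos, %crd), %val
//      : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>
//        to tensor<?x?xf64, #CSR>
//
// while the disassemble path emits the inverse sparse_tensor.disassemble,
// whose trailing "count" results are dropped by the code above.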

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// TODO: refine output sparse tensors to work well with external frameworks
//
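// An illustrative instance (again assuming the CSR-like encoding #CSR; the
// encoding and buffer types are assumptions for exposition only):
//
//   func.func @foo(%t : tensor<?x?xf64, #CSR>) { ... }
//
// becomes
//
//   func.func private @_internal_foo(%t : tensor<?x?xf64, #CSR>) { ... }
//   func.func @foo(%pos : tensor<?xindex>, %crd : tensor<?xindex>,
//                  %val : tensor<?xf64>) {
//     %t = sparse_tensor.assemble (%pos, %crd), %val : ... to ...
//     call @_internal_foo(%t) : ...
//     return
//   }
//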
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
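    // Note that `extraTypes` now holds the buffer types for sparse results;
    // these are appended to the wrapper's argument list below, so the caller
    // supplies the storage into which the results are disassembled.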

    // Only sparse inputs or outputs need a wrapper method; the converted
    // type lists only change size when a sparse tensor was expanded.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string internalName = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(internalName);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto internalRef = SymbolRefAttr::get(context, internalName);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(),
                                              internalRef, inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential c-interface property (i.e., the
    // llvm.emit_c_interface attribute).
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
  patterns.add<SparseFuncAssembler>(patterns.getContext());
}
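
// Example driver (a minimal sketch, not the actual pass): a pass could apply
// these patterns greedily over a module using the standard greedy rewrite
// driver from mlir/Transforms/GreedyPatternRewriteDriver.h:
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     populateSparseAssembler(patterns);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }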