xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (revision 0e34dbb4f452013eab89a0a8f04a436ff6c408d4)
1 //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
15 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "llvm/Support/FormatVariadic.h"
18 
19 using namespace mlir;
20 using namespace sparse_tensor;
21 
22 //===----------------------------------------------------------------------===//
23 // Helper methods.
24 //===----------------------------------------------------------------------===//
25 
26 // Convert type range to new types range, with sparse tensors externalized.
27 static void convTypes(bool &hasAnnotation, TypeRange types,
28                       SmallVectorImpl<Type> &convTypes,
29                       SmallVectorImpl<Type> *extraTypes, bool directOut) {
30   for (auto type : types) {
31     // All "dense" data passes through unmodified.
32     if (!getSparseTensorEncoding(type)) {
33       convTypes.push_back(type);
34       continue;
35     }
36     hasAnnotation = true;
37 
38     // Convert the external representations of the pos/crd/val arrays.
39     const SparseTensorType stt(cast<RankedTensorType>(type));
40     foreachFieldAndTypeInSparseTensor(
41         stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
42                                                  SparseTensorFieldKind kind,
43                                                  Level, LevelType) {
44           if (kind == SparseTensorFieldKind::PosMemRef ||
45               kind == SparseTensorFieldKind::CrdMemRef ||
46               kind == SparseTensorFieldKind::ValMemRef) {
47             auto rtp = cast<ShapedType>(t);
48             if (!directOut) {
49               rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
50               if (extraTypes)
51                 extraTypes->push_back(rtp);
52             }
53             convTypes.push_back(rtp);
54           }
55           return true;
56         });
57   }
58 }
59 
60 // Convert input and output values to [dis]assemble ops for sparse tensors.
61 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
62                      ValueRange fromVals, ValueRange extraVals,
63                      SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
64                      bool directOut) {
65   unsigned idx = 0;
66   for (auto type : types) {
67     // All "dense" data passes through unmodified.
68     if (!getSparseTensorEncoding(type)) {
69       toVals.push_back(fromVals[idx++]);
70       continue;
71     }
72     // Handle sparse data.
73     auto rtp = cast<RankedTensorType>(type);
74     const SparseTensorType stt(rtp);
75     SmallVector<Value> inputs;
76     SmallVector<Type> retTypes;
77     SmallVector<Type> cntTypes;
78     if (!isIn)
79       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
80 
81     // Collect the external representations of the pos/crd/val arrays.
82     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
83                                                      SparseTensorFieldKind kind,
84                                                      Level lv, LevelType) {
85       if (kind == SparseTensorFieldKind::PosMemRef ||
86           kind == SparseTensorFieldKind::CrdMemRef ||
87           kind == SparseTensorFieldKind::ValMemRef) {
88         if (isIn) {
89           inputs.push_back(fromVals[idx++]);
90         } else if (directOut) {
91           Value mem;
92           if (kind == SparseTensorFieldKind::PosMemRef)
93             mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
94                                                                lv);
95           else if (kind == SparseTensorFieldKind::CrdMemRef)
96             mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
97                                                                  lv);
98           else
99             mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
100           toVals.push_back(mem);
101         } else {
102           ShapedType rtp = cast<ShapedType>(t);
103           rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
104           inputs.push_back(extraVals[extra++]);
105           retTypes.push_back(rtp);
106           cntTypes.push_back(builder.getIndexType());
107         }
108       }
109       return true;
110     });
111 
112     if (isIn) {
113       // Assemble multiple inputs into a single sparse tensor.
114       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
115       toVals.push_back(a.getResult());
116     } else if (!directOut) {
117       // Disassemble a single sparse input into multiple outputs.
118       // Note that this includes the counters, which are dropped.
119       unsigned len = retTypes.size();
120       retTypes.append(cntTypes);
121       auto d =
122           builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
123       for (unsigned i = 0; i < len; i++)
124         toVals.push_back(d.getResult(i));
125     }
126   }
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // Rewriting rules.
131 //===----------------------------------------------------------------------===//
132 
133 namespace {
134 
135 // A rewriting rules that converts public entry methods that use sparse tensors
136 // as input parameters and/or output return values into wrapper methods that
137 // [dis]assemble the individual tensors that constitute the actual storage used
138 // externally into MLIR sparse tensors before calling the original method.
139 //
140 // In particular, each sparse tensor input
141 //
142 // void foo(..., t, ...) { }
143 //
144 // makes the original foo() internal and adds the following wrapper method
145 //
146 // void foo(..., t1..tn, ...) {
147 //   t = assemble t1..tn
148 //   _internal_foo(..., t, ...)
149 // }
150 //
151 // and likewise, each output tensor
152 //
153 // ... T ... bar(...) { return ..., t, ...; }
154 //
155 // makes the original bar() internal and adds the following wrapper method
156 //
157 // ... T1..TN ... bar(..., t1'..tn') {
158 //   ..., t, ... = _internal_bar(...)
159 //   t1..tn = disassemble t, t1'..tn'
160 //   return ..., t1..tn, ...
161 // }
162 //
163 // (with a direct-out variant without the disassemble).
164 //
165 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
166   using OpRewritePattern::OpRewritePattern;
167 
168   SparseFuncAssembler(MLIRContext *context, bool dO)
169       : OpRewritePattern(context), directOut(dO) {}
170 
171   LogicalResult matchAndRewrite(func::FuncOp funcOp,
172                                 PatternRewriter &rewriter) const override {
173     // Only rewrite public entry methods.
174     if (funcOp.isPrivate())
175       return failure();
176 
177     // Translate sparse tensor types to external types.
178     SmallVector<Type> inputTypes;
179     SmallVector<Type> outputTypes;
180     SmallVector<Type> extraTypes;
181     bool hasAnnotation = false;
182     convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
183               false);
184     convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
185               directOut);
186 
187     // Only sparse inputs or outputs need a wrapper method.
188     if (!hasAnnotation)
189       return failure();
190 
191     // Modify the original method into an internal, private method.
192     auto orgName = funcOp.getName();
193     std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
194     funcOp.setName(wrapper);
195     funcOp.setPrivate();
196 
197     // Start the new public wrapper method with original name.
198     Location loc = funcOp.getLoc();
199     ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
200     MLIRContext *context = modOp.getContext();
201     OpBuilder moduleBuilder(modOp.getBodyRegion());
202     unsigned extra = inputTypes.size();
203     inputTypes.append(extraTypes);
204     auto func = moduleBuilder.create<func::FuncOp>(
205         loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
206     func.setPublic();
207 
208     // Construct new wrapper method body.
209     OpBuilder::InsertionGuard insertionGuard(rewriter);
210     Block *body = func.addEntryBlock();
211     rewriter.setInsertionPointToStart(body);
212 
213     // Convert inputs.
214     SmallVector<Value> inputs;
215     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
216              ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
217 
218     // Call the original, now private method. A subsequent inlining pass can
219     // determine whether cloning the method body in place is worthwhile.
220     auto org = SymbolRefAttr::get(context, wrapper);
221     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
222                                               inputs);
223 
224     // Convert outputs and return.
225     SmallVector<Value> outputs;
226     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
227              body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
228     rewriter.create<func::ReturnOp>(loc, outputs);
229 
230     // Finally, migrate a potential c-interface property.
231     if (funcOp->getAttrOfType<UnitAttr>(
232             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
233       func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
234                     UnitAttr::get(context));
235       funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
236     }
237     return success();
238   }
239 
240 private:
241   const bool directOut;
242 };
243 
244 } // namespace
245 
246 //===----------------------------------------------------------------------===//
247 // Public method for populating conversion rules.
248 //===----------------------------------------------------------------------===//
249 
250 void mlir::populateSparseAssembler(RewritePatternSet &patterns,
251                                    bool directOut) {
252   patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
253 }
254