//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Converts a range of types to a new range in which all sparse tensor types
// have been externalized into the types of their constituent pos/crd/val
// arrays.
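//
// For example (an illustrative sketch, assuming a 2-D CSR encoding named
// #CSR), a sparse argument or result of type
//   tensor<10x20xf64, #CSR>
// is externalized into the three constituent arrays
//   tensor<?xindex> (positions), tensor<?xindex> (coordinates),
//   tensor<?xf64> (values)
// while, with directOut set, the underlying memref types are kept instead.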
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }

    // Convert the external representations of the pos/crd/val arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(
        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
                                                 SparseTensorFieldKind kind,
                                                 Level, LevelType) {
          if (kind == SparseTensorFieldKind::PosMemRef ||
              kind == SparseTensorFieldKind::CrdMemRef ||
              kind == SparseTensorFieldKind::ValMemRef) {
            auto rtp = cast<ShapedType>(t);
            if (!directOut) {
              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
              if (extraTypes)
                extraTypes->push_back(rtp);
            }
            convTypes.push_back(rtp);
          }
          return true;
        });
  }
}

// Convert input and output values to [dis]assemble ops for sparse tensors.
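//
// On the way in (isIn == true), the incoming pos/crd/val tensors are
// assembled into a single sparse tensor value; on the way out, each sparse
// result is disassembled into such tensors again, except with directOut,
// where the underlying memrefs are extracted and returned directly. The
// extra index denotes the position of the first caller-provided output
// buffer within extraVals.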
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes; // the actual-length counters (dropped below)
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level lv, LevelType) {
      if (kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else if (directOut) {
          Value mem;
          if (kind == SparseTensorFieldKind::PosMemRef)
            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
                                                               lv);
          else if (kind == SparseTensorFieldKind::CrdMemRef)
            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
                                                                 lv);
          else
            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
          toVals.push_back(mem);
        } else {
          ShapedType rtp = cast<ShapedType>(t);
          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else if (!directOut) {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
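//
// As an illustrative sketch (assuming a CSR-annotated type #CSR and with
// exact operand/result types elided as "..."), an entry method
//
//   func.func @foo(%t: tensor<10x10xf64, #CSR>) { ... }
//
// becomes
//
//   func.func private @_internal_foo(%t: tensor<10x10xf64, #CSR>) { ... }
//   func.func @foo(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                  %val: tensor<?xf64>) {
//     %t = sparse_tensor.assemble (%pos, %crd), %val : ...
//     call @_internal_foo(%t) : ...
//   }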
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);

    // Only methods with sparse inputs or outputs need a wrapper method.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
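    // Position of the first extra output buffer within the wrapper's
    // argument list, i.e., right past all converted original inputs.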
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  const bool directOut;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
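
// Example driver (a minimal sketch, not part of this file): a pass would
// typically populate these patterns and apply them greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));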