xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (revision 0e34dbb4f452013eab89a0a8f04a436ff6c408d4)
133b463adSAart Bik //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
233b463adSAart Bik //
333b463adSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
433b463adSAart Bik // See https://llvm.org/LICENSE.txt for license information.
533b463adSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
633b463adSAart Bik //
733b463adSAart Bik //===----------------------------------------------------------------------===//
833b463adSAart Bik 
933b463adSAart Bik #include "Utils/CodegenUtils.h"
1033b463adSAart Bik 
115122a2c2SAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1233b463adSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1333b463adSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
1433b463adSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
1533b463adSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
1633b463adSAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
1733b463adSAart Bik #include "llvm/Support/FormatVariadic.h"
1833b463adSAart Bik 
1933b463adSAart Bik using namespace mlir;
2033b463adSAart Bik using namespace sparse_tensor;
2133b463adSAart Bik 
2233b463adSAart Bik //===----------------------------------------------------------------------===//
2333b463adSAart Bik // Helper methods.
2433b463adSAart Bik //===----------------------------------------------------------------------===//
2533b463adSAart Bik 
2633b463adSAart Bik // Convert type range to new types range, with sparse tensors externalized.
27*0e34dbb4SAart Bik static void convTypes(bool &hasAnnotation, TypeRange types,
28*0e34dbb4SAart Bik                       SmallVectorImpl<Type> &convTypes,
295122a2c2SAart Bik                       SmallVectorImpl<Type> *extraTypes, bool directOut) {
3033b463adSAart Bik   for (auto type : types) {
3133b463adSAart Bik     // All "dense" data passes through unmodified.
3233b463adSAart Bik     if (!getSparseTensorEncoding(type)) {
3333b463adSAart Bik       convTypes.push_back(type);
3433b463adSAart Bik       continue;
3533b463adSAart Bik     }
36*0e34dbb4SAart Bik     hasAnnotation = true;
37f40ee6e8SPeiming Liu 
385122a2c2SAart Bik     // Convert the external representations of the pos/crd/val arrays.
39fc9f1d49SPeiming Liu     const SparseTensorType stt(cast<RankedTensorType>(type));
405122a2c2SAart Bik     foreachFieldAndTypeInSparseTensor(
415122a2c2SAart Bik         stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
42f40ee6e8SPeiming Liu                                                  SparseTensorFieldKind kind,
43f40ee6e8SPeiming Liu                                                  Level, LevelType) {
445122a2c2SAart Bik           if (kind == SparseTensorFieldKind::PosMemRef ||
455122a2c2SAart Bik               kind == SparseTensorFieldKind::CrdMemRef ||
46fc9f1d49SPeiming Liu               kind == SparseTensorFieldKind::ValMemRef) {
47a5757c5bSChristian Sigg             auto rtp = cast<ShapedType>(t);
485122a2c2SAart Bik             if (!directOut) {
495122a2c2SAart Bik               rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
50f40ee6e8SPeiming Liu               if (extraTypes)
51f40ee6e8SPeiming Liu                 extraTypes->push_back(rtp);
5233b463adSAart Bik             }
535122a2c2SAart Bik             convTypes.push_back(rtp);
545122a2c2SAart Bik           }
55f40ee6e8SPeiming Liu           return true;
56f40ee6e8SPeiming Liu         });
5733b463adSAart Bik   }
5833b463adSAart Bik }
5933b463adSAart Bik 
604d273b94SAart Bik // Convert input and output values to [dis]assemble ops for sparse tensors.
61f40ee6e8SPeiming Liu static void convVals(OpBuilder &builder, Location loc, TypeRange types,
6233b463adSAart Bik                      ValueRange fromVals, ValueRange extraVals,
635122a2c2SAart Bik                      SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
645122a2c2SAart Bik                      bool directOut) {
6533b463adSAart Bik   unsigned idx = 0;
6633b463adSAart Bik   for (auto type : types) {
6733b463adSAart Bik     // All "dense" data passes through unmodified.
6833b463adSAart Bik     if (!getSparseTensorEncoding(type)) {
6933b463adSAart Bik       toVals.push_back(fromVals[idx++]);
7033b463adSAart Bik       continue;
7133b463adSAart Bik     }
72fc9f1d49SPeiming Liu     // Handle sparse data.
7333b463adSAart Bik     auto rtp = cast<RankedTensorType>(type);
7433b463adSAart Bik     const SparseTensorType stt(rtp);
7533b463adSAart Bik     SmallVector<Value> inputs;
7633b463adSAart Bik     SmallVector<Type> retTypes;
7733b463adSAart Bik     SmallVector<Type> cntTypes;
78fc9f1d49SPeiming Liu     if (!isIn)
79fc9f1d49SPeiming Liu       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
80f40ee6e8SPeiming Liu 
815122a2c2SAart Bik     // Collect the external representations of the pos/crd/val arrays.
82f40ee6e8SPeiming Liu     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
83f40ee6e8SPeiming Liu                                                      SparseTensorFieldKind kind,
845122a2c2SAart Bik                                                      Level lv, LevelType) {
855122a2c2SAart Bik       if (kind == SparseTensorFieldKind::PosMemRef ||
865122a2c2SAart Bik           kind == SparseTensorFieldKind::CrdMemRef ||
87fc9f1d49SPeiming Liu           kind == SparseTensorFieldKind::ValMemRef) {
8833b463adSAart Bik         if (isIn) {
8933b463adSAart Bik           inputs.push_back(fromVals[idx++]);
905122a2c2SAart Bik         } else if (directOut) {
915122a2c2SAart Bik           Value mem;
925122a2c2SAart Bik           if (kind == SparseTensorFieldKind::PosMemRef)
935122a2c2SAart Bik             mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
945122a2c2SAart Bik                                                                lv);
955122a2c2SAart Bik           else if (kind == SparseTensorFieldKind::CrdMemRef)
965122a2c2SAart Bik             mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
975122a2c2SAart Bik                                                                  lv);
985122a2c2SAart Bik           else
995122a2c2SAart Bik             mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
1005122a2c2SAart Bik           toVals.push_back(mem);
10133b463adSAart Bik         } else {
102a5757c5bSChristian Sigg           ShapedType rtp = cast<ShapedType>(t);
1035122a2c2SAart Bik           rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
10433b463adSAart Bik           inputs.push_back(extraVals[extra++]);
105f40ee6e8SPeiming Liu           retTypes.push_back(rtp);
106fc9f1d49SPeiming Liu           cntTypes.push_back(builder.getIndexType());
10733b463adSAart Bik         }
10833b463adSAart Bik       }
109f40ee6e8SPeiming Liu       return true;
110f40ee6e8SPeiming Liu     });
111f40ee6e8SPeiming Liu 
11233b463adSAart Bik     if (isIn) {
11333b463adSAart Bik       // Assemble multiple inputs into a single sparse tensor.
11433b463adSAart Bik       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
11533b463adSAart Bik       toVals.push_back(a.getResult());
1165122a2c2SAart Bik     } else if (!directOut) {
11733b463adSAart Bik       // Disassemble a single sparse input into multiple outputs.
11833b463adSAart Bik       // Note that this includes the counters, which are dropped.
11933b463adSAart Bik       unsigned len = retTypes.size();
12033b463adSAart Bik       retTypes.append(cntTypes);
12133b463adSAart Bik       auto d =
12233b463adSAart Bik           builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
12333b463adSAart Bik       for (unsigned i = 0; i < len; i++)
12433b463adSAart Bik         toVals.push_back(d.getResult(i));
12533b463adSAart Bik     }
12633b463adSAart Bik   }
12733b463adSAart Bik }
12833b463adSAart Bik 
12933b463adSAart Bik //===----------------------------------------------------------------------===//
13033b463adSAart Bik // Rewriting rules.
13133b463adSAart Bik //===----------------------------------------------------------------------===//
13233b463adSAart Bik 
13333b463adSAart Bik namespace {
13433b463adSAart Bik 
13533b463adSAart Bik // A rewriting rules that converts public entry methods that use sparse tensors
136d00e6d07SAart Bik // as input parameters and/or output return values into wrapper methods that
137d00e6d07SAart Bik // [dis]assemble the individual tensors that constitute the actual storage used
138d00e6d07SAart Bik // externally into MLIR sparse tensors before calling the original method.
13933b463adSAart Bik //
14033b463adSAart Bik // In particular, each sparse tensor input
14133b463adSAart Bik //
14233b463adSAart Bik // void foo(..., t, ...) { }
14333b463adSAart Bik //
144d00e6d07SAart Bik // makes the original foo() internal and adds the following wrapper method
14533b463adSAart Bik //
146d00e6d07SAart Bik // void foo(..., t1..tn, ...) {
14733b463adSAart Bik //   t = assemble t1..tn
148d00e6d07SAart Bik //   _internal_foo(..., t, ...)
14933b463adSAart Bik // }
15033b463adSAart Bik //
15133b463adSAart Bik // and likewise, each output tensor
15233b463adSAart Bik //
15333b463adSAart Bik // ... T ... bar(...) { return ..., t, ...; }
15433b463adSAart Bik //
155d00e6d07SAart Bik // makes the original bar() internal and adds the following wrapper method
15633b463adSAart Bik //
157d00e6d07SAart Bik // ... T1..TN ... bar(..., t1'..tn') {
158d00e6d07SAart Bik //   ..., t, ... = _internal_bar(...)
15933b463adSAart Bik //   t1..tn = disassemble t, t1'..tn'
16033b463adSAart Bik //   return ..., t1..tn, ...
16133b463adSAart Bik // }
16233b463adSAart Bik //
1635122a2c2SAart Bik // (with a direct-out variant without the disassemble).
16433b463adSAart Bik //
16533b463adSAart Bik struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
16633b463adSAart Bik   using OpRewritePattern::OpRewritePattern;
16733b463adSAart Bik 
1685122a2c2SAart Bik   SparseFuncAssembler(MLIRContext *context, bool dO)
1695122a2c2SAart Bik       : OpRewritePattern(context), directOut(dO) {}
1705122a2c2SAart Bik 
17133b463adSAart Bik   LogicalResult matchAndRewrite(func::FuncOp funcOp,
17233b463adSAart Bik                                 PatternRewriter &rewriter) const override {
173d00e6d07SAart Bik     // Only rewrite public entry methods.
174d00e6d07SAart Bik     if (funcOp.isPrivate())
17533b463adSAart Bik       return failure();
17633b463adSAart Bik 
17733b463adSAart Bik     // Translate sparse tensor types to external types.
17833b463adSAart Bik     SmallVector<Type> inputTypes;
17933b463adSAart Bik     SmallVector<Type> outputTypes;
18033b463adSAart Bik     SmallVector<Type> extraTypes;
181*0e34dbb4SAart Bik     bool hasAnnotation = false;
182*0e34dbb4SAart Bik     convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
183*0e34dbb4SAart Bik               false);
184*0e34dbb4SAart Bik     convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
185*0e34dbb4SAart Bik               directOut);
18633b463adSAart Bik 
187d00e6d07SAart Bik     // Only sparse inputs or outputs need a wrapper method.
188*0e34dbb4SAart Bik     if (!hasAnnotation)
18933b463adSAart Bik       return failure();
19033b463adSAart Bik 
191d00e6d07SAart Bik     // Modify the original method into an internal, private method.
192d00e6d07SAart Bik     auto orgName = funcOp.getName();
193d00e6d07SAart Bik     std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
194d00e6d07SAart Bik     funcOp.setName(wrapper);
195d00e6d07SAart Bik     funcOp.setPrivate();
196d00e6d07SAart Bik 
197d00e6d07SAart Bik     // Start the new public wrapper method with original name.
19833b463adSAart Bik     Location loc = funcOp.getLoc();
19933b463adSAart Bik     ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
20033b463adSAart Bik     MLIRContext *context = modOp.getContext();
20133b463adSAart Bik     OpBuilder moduleBuilder(modOp.getBodyRegion());
20233b463adSAart Bik     unsigned extra = inputTypes.size();
20333b463adSAart Bik     inputTypes.append(extraTypes);
20433b463adSAart Bik     auto func = moduleBuilder.create<func::FuncOp>(
205d00e6d07SAart Bik         loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
20633b463adSAart Bik     func.setPublic();
20733b463adSAart Bik 
208d00e6d07SAart Bik     // Construct new wrapper method body.
20933b463adSAart Bik     OpBuilder::InsertionGuard insertionGuard(rewriter);
21033b463adSAart Bik     Block *body = func.addEntryBlock();
21133b463adSAart Bik     rewriter.setInsertionPointToStart(body);
21233b463adSAart Bik 
21333b463adSAart Bik     // Convert inputs.
21433b463adSAart Bik     SmallVector<Value> inputs;
21533b463adSAart Bik     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
2165122a2c2SAart Bik              ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
21733b463adSAart Bik 
2184d273b94SAart Bik     // Call the original, now private method. A subsequent inlining pass can
2194d273b94SAart Bik     // determine whether cloning the method body in place is worthwhile.
220d00e6d07SAart Bik     auto org = SymbolRefAttr::get(context, wrapper);
22133b463adSAart Bik     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
22233b463adSAart Bik                                               inputs);
22333b463adSAart Bik 
22433b463adSAart Bik     // Convert outputs and return.
22533b463adSAart Bik     SmallVector<Value> outputs;
22633b463adSAart Bik     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
2275122a2c2SAart Bik              body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
22833b463adSAart Bik     rewriter.create<func::ReturnOp>(loc, outputs);
22933b463adSAart Bik 
230d00e6d07SAart Bik     // Finally, migrate a potential c-interface property.
231d00e6d07SAart Bik     if (funcOp->getAttrOfType<UnitAttr>(
232d00e6d07SAart Bik             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
233d00e6d07SAart Bik       func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
234d00e6d07SAart Bik                     UnitAttr::get(context));
23533b463adSAart Bik       funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
236d00e6d07SAart Bik     }
23733b463adSAart Bik     return success();
23833b463adSAart Bik   }
2395122a2c2SAart Bik 
2405122a2c2SAart Bik private:
2415122a2c2SAart Bik   const bool directOut;
24233b463adSAart Bik };
24333b463adSAart Bik 
24433b463adSAart Bik } // namespace
24533b463adSAart Bik 
24633b463adSAart Bik //===----------------------------------------------------------------------===//
24733b463adSAart Bik // Public method for populating conversion rules.
24833b463adSAart Bik //===----------------------------------------------------------------------===//
24933b463adSAart Bik 
2505122a2c2SAart Bik void mlir::populateSparseAssembler(RewritePatternSet &patterns,
2515122a2c2SAart Bik                                    bool directOut) {
2525122a2c2SAart Bik   patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
25333b463adSAart Bik }
254