133b463adSAart Bik //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===// 233b463adSAart Bik // 333b463adSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 433b463adSAart Bik // See https://llvm.org/LICENSE.txt for license information. 533b463adSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 633b463adSAart Bik // 733b463adSAart Bik //===----------------------------------------------------------------------===// 833b463adSAart Bik 933b463adSAart Bik #include "Utils/CodegenUtils.h" 1033b463adSAart Bik 115122a2c2SAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1233b463adSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 1333b463adSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 1433b463adSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 1533b463adSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 1633b463adSAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h" 1733b463adSAart Bik #include "llvm/Support/FormatVariadic.h" 1833b463adSAart Bik 1933b463adSAart Bik using namespace mlir; 2033b463adSAart Bik using namespace sparse_tensor; 2133b463adSAart Bik 2233b463adSAart Bik //===----------------------------------------------------------------------===// 2333b463adSAart Bik // Helper methods. 2433b463adSAart Bik //===----------------------------------------------------------------------===// 2533b463adSAart Bik 2633b463adSAart Bik // Convert type range to new types range, with sparse tensors externalized. 27*0e34dbb4SAart Bik static void convTypes(bool &hasAnnotation, TypeRange types, 28*0e34dbb4SAart Bik SmallVectorImpl<Type> &convTypes, 295122a2c2SAart Bik SmallVectorImpl<Type> *extraTypes, bool directOut) { 3033b463adSAart Bik for (auto type : types) { 3133b463adSAart Bik // All "dense" data passes through unmodified. 3233b463adSAart Bik if (!getSparseTensorEncoding(type)) { 3333b463adSAart Bik convTypes.push_back(type); 3433b463adSAart Bik continue; 3533b463adSAart Bik } 36*0e34dbb4SAart Bik hasAnnotation = true; 37f40ee6e8SPeiming Liu 385122a2c2SAart Bik // Convert the external representations of the pos/crd/val arrays. 39fc9f1d49SPeiming Liu const SparseTensorType stt(cast<RankedTensorType>(type)); 405122a2c2SAart Bik foreachFieldAndTypeInSparseTensor( 415122a2c2SAart Bik stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex, 42f40ee6e8SPeiming Liu SparseTensorFieldKind kind, 43f40ee6e8SPeiming Liu Level, LevelType) { 445122a2c2SAart Bik if (kind == SparseTensorFieldKind::PosMemRef || 455122a2c2SAart Bik kind == SparseTensorFieldKind::CrdMemRef || 46fc9f1d49SPeiming Liu kind == SparseTensorFieldKind::ValMemRef) { 47a5757c5bSChristian Sigg auto rtp = cast<ShapedType>(t); 485122a2c2SAart Bik if (!directOut) { 495122a2c2SAart Bik rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); 50f40ee6e8SPeiming Liu if (extraTypes) 51f40ee6e8SPeiming Liu extraTypes->push_back(rtp); 5233b463adSAart Bik } 535122a2c2SAart Bik convTypes.push_back(rtp); 545122a2c2SAart Bik } 55f40ee6e8SPeiming Liu return true; 56f40ee6e8SPeiming Liu }); 5733b463adSAart Bik } 5833b463adSAart Bik } 5933b463adSAart Bik 604d273b94SAart Bik // Convert input and output values to [dis]assemble ops for sparse tensors. 61f40ee6e8SPeiming Liu static void convVals(OpBuilder &builder, Location loc, TypeRange types, 6233b463adSAart Bik ValueRange fromVals, ValueRange extraVals, 635122a2c2SAart Bik SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn, 645122a2c2SAart Bik bool directOut) { 6533b463adSAart Bik unsigned idx = 0; 6633b463adSAart Bik for (auto type : types) { 6733b463adSAart Bik // All "dense" data passes through unmodified. 6833b463adSAart Bik if (!getSparseTensorEncoding(type)) { 6933b463adSAart Bik toVals.push_back(fromVals[idx++]); 7033b463adSAart Bik continue; 7133b463adSAart Bik } 72fc9f1d49SPeiming Liu // Handle sparse data. 7333b463adSAart Bik auto rtp = cast<RankedTensorType>(type); 7433b463adSAart Bik const SparseTensorType stt(rtp); 7533b463adSAart Bik SmallVector<Value> inputs; 7633b463adSAart Bik SmallVector<Type> retTypes; 7733b463adSAart Bik SmallVector<Type> cntTypes; 78fc9f1d49SPeiming Liu if (!isIn) 79fc9f1d49SPeiming Liu inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble 80f40ee6e8SPeiming Liu 815122a2c2SAart Bik // Collect the external representations of the pos/crd/val arrays. 82f40ee6e8SPeiming Liu foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex, 83f40ee6e8SPeiming Liu SparseTensorFieldKind kind, 845122a2c2SAart Bik Level lv, LevelType) { 855122a2c2SAart Bik if (kind == SparseTensorFieldKind::PosMemRef || 865122a2c2SAart Bik kind == SparseTensorFieldKind::CrdMemRef || 87fc9f1d49SPeiming Liu kind == SparseTensorFieldKind::ValMemRef) { 8833b463adSAart Bik if (isIn) { 8933b463adSAart Bik inputs.push_back(fromVals[idx++]); 905122a2c2SAart Bik } else if (directOut) { 915122a2c2SAart Bik Value mem; 925122a2c2SAart Bik if (kind == SparseTensorFieldKind::PosMemRef) 935122a2c2SAart Bik mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0], 945122a2c2SAart Bik lv); 955122a2c2SAart Bik else if (kind == SparseTensorFieldKind::CrdMemRef) 965122a2c2SAart Bik mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0], 975122a2c2SAart Bik lv); 985122a2c2SAart Bik else 995122a2c2SAart Bik mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]); 1005122a2c2SAart Bik toVals.push_back(mem); 10133b463adSAart Bik } else { 102a5757c5bSChristian Sigg ShapedType rtp = cast<ShapedType>(t); 1035122a2c2SAart Bik rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); 10433b463adSAart Bik inputs.push_back(extraVals[extra++]); 105f40ee6e8SPeiming Liu retTypes.push_back(rtp); 106fc9f1d49SPeiming Liu cntTypes.push_back(builder.getIndexType()); 10733b463adSAart Bik } 10833b463adSAart Bik } 109f40ee6e8SPeiming Liu return true; 110f40ee6e8SPeiming Liu }); 111f40ee6e8SPeiming Liu 11233b463adSAart Bik if (isIn) { 11333b463adSAart Bik // Assemble multiple inputs into a single sparse tensor. 11433b463adSAart Bik auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs); 11533b463adSAart Bik toVals.push_back(a.getResult()); 1165122a2c2SAart Bik } else if (!directOut) { 11733b463adSAart Bik // Disassemble a single sparse input into multiple outputs. 11833b463adSAart Bik // Note that this includes the counters, which are dropped. 11933b463adSAart Bik unsigned len = retTypes.size(); 12033b463adSAart Bik retTypes.append(cntTypes); 12133b463adSAart Bik auto d = 12233b463adSAart Bik builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs); 12333b463adSAart Bik for (unsigned i = 0; i < len; i++) 12433b463adSAart Bik toVals.push_back(d.getResult(i)); 12533b463adSAart Bik } 12633b463adSAart Bik } 12733b463adSAart Bik } 12833b463adSAart Bik 12933b463adSAart Bik //===----------------------------------------------------------------------===// 13033b463adSAart Bik // Rewriting rules. 13133b463adSAart Bik //===----------------------------------------------------------------------===// 13233b463adSAart Bik 13333b463adSAart Bik namespace { 13433b463adSAart Bik 13533b463adSAart Bik // A rewriting rules that converts public entry methods that use sparse tensors 136d00e6d07SAart Bik // as input parameters and/or output return values into wrapper methods that 137d00e6d07SAart Bik // [dis]assemble the individual tensors that constitute the actual storage used 138d00e6d07SAart Bik // externally into MLIR sparse tensors before calling the original method. 13933b463adSAart Bik // 14033b463adSAart Bik // In particular, each sparse tensor input 14133b463adSAart Bik // 14233b463adSAart Bik // void foo(..., t, ...) { } 14333b463adSAart Bik // 144d00e6d07SAart Bik // makes the original foo() internal and adds the following wrapper method 14533b463adSAart Bik // 146d00e6d07SAart Bik // void foo(..., t1..tn, ...) { 14733b463adSAart Bik // t = assemble t1..tn 148d00e6d07SAart Bik // _internal_foo(..., t, ...) 14933b463adSAart Bik // } 15033b463adSAart Bik // 15133b463adSAart Bik // and likewise, each output tensor 15233b463adSAart Bik // 15333b463adSAart Bik // ... T ... bar(...) { return ..., t, ...; } 15433b463adSAart Bik // 155d00e6d07SAart Bik // makes the original bar() internal and adds the following wrapper method 15633b463adSAart Bik // 157d00e6d07SAart Bik // ... T1..TN ... bar(..., t1'..tn') { 158d00e6d07SAart Bik // ..., t, ... = _internal_bar(...) 15933b463adSAart Bik // t1..tn = disassemble t, t1'..tn' 16033b463adSAart Bik // return ..., t1..tn, ... 16133b463adSAart Bik // } 16233b463adSAart Bik // 1635122a2c2SAart Bik // (with a direct-out variant without the disassemble). 16433b463adSAart Bik // 16533b463adSAart Bik struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> { 16633b463adSAart Bik using OpRewritePattern::OpRewritePattern; 16733b463adSAart Bik 1685122a2c2SAart Bik SparseFuncAssembler(MLIRContext *context, bool dO) 1695122a2c2SAart Bik : OpRewritePattern(context), directOut(dO) {} 1705122a2c2SAart Bik 17133b463adSAart Bik LogicalResult matchAndRewrite(func::FuncOp funcOp, 17233b463adSAart Bik PatternRewriter &rewriter) const override { 173d00e6d07SAart Bik // Only rewrite public entry methods. 174d00e6d07SAart Bik if (funcOp.isPrivate()) 17533b463adSAart Bik return failure(); 17633b463adSAart Bik 17733b463adSAart Bik // Translate sparse tensor types to external types. 17833b463adSAart Bik SmallVector<Type> inputTypes; 17933b463adSAart Bik SmallVector<Type> outputTypes; 18033b463adSAart Bik SmallVector<Type> extraTypes; 181*0e34dbb4SAart Bik bool hasAnnotation = false; 182*0e34dbb4SAart Bik convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr, 183*0e34dbb4SAart Bik false); 184*0e34dbb4SAart Bik convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes, 185*0e34dbb4SAart Bik directOut); 18633b463adSAart Bik 187d00e6d07SAart Bik // Only sparse inputs or outputs need a wrapper method. 188*0e34dbb4SAart Bik if (!hasAnnotation) 18933b463adSAart Bik return failure(); 19033b463adSAart Bik 191d00e6d07SAart Bik // Modify the original method into an internal, private method. 192d00e6d07SAart Bik auto orgName = funcOp.getName(); 193d00e6d07SAart Bik std::string wrapper = llvm::formatv("_internal_{0}", orgName).str(); 194d00e6d07SAart Bik funcOp.setName(wrapper); 195d00e6d07SAart Bik funcOp.setPrivate(); 196d00e6d07SAart Bik 197d00e6d07SAart Bik // Start the new public wrapper method with original name. 19833b463adSAart Bik Location loc = funcOp.getLoc(); 19933b463adSAart Bik ModuleOp modOp = funcOp->getParentOfType<ModuleOp>(); 20033b463adSAart Bik MLIRContext *context = modOp.getContext(); 20133b463adSAart Bik OpBuilder moduleBuilder(modOp.getBodyRegion()); 20233b463adSAart Bik unsigned extra = inputTypes.size(); 20333b463adSAart Bik inputTypes.append(extraTypes); 20433b463adSAart Bik auto func = moduleBuilder.create<func::FuncOp>( 205d00e6d07SAart Bik loc, orgName, FunctionType::get(context, inputTypes, outputTypes)); 20633b463adSAart Bik func.setPublic(); 20733b463adSAart Bik 208d00e6d07SAart Bik // Construct new wrapper method body. 20933b463adSAart Bik OpBuilder::InsertionGuard insertionGuard(rewriter); 21033b463adSAart Bik Block *body = func.addEntryBlock(); 21133b463adSAart Bik rewriter.setInsertionPointToStart(body); 21233b463adSAart Bik 21333b463adSAart Bik // Convert inputs. 21433b463adSAart Bik SmallVector<Value> inputs; 21533b463adSAart Bik convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(), 2165122a2c2SAart Bik ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut); 21733b463adSAart Bik 2184d273b94SAart Bik // Call the original, now private method. A subsequent inlining pass can 2194d273b94SAart Bik // determine whether cloning the method body in place is worthwhile. 220d00e6d07SAart Bik auto org = SymbolRefAttr::get(context, wrapper); 22133b463adSAart Bik auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org, 22233b463adSAart Bik inputs); 22333b463adSAart Bik 22433b463adSAart Bik // Convert outputs and return. 22533b463adSAart Bik SmallVector<Value> outputs; 22633b463adSAart Bik convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(), 2275122a2c2SAart Bik body->getArguments(), outputs, extra, /*isIn=*/false, directOut); 22833b463adSAart Bik rewriter.create<func::ReturnOp>(loc, outputs); 22933b463adSAart Bik 230d00e6d07SAart Bik // Finally, migrate a potential c-interface property. 231d00e6d07SAart Bik if (funcOp->getAttrOfType<UnitAttr>( 232d00e6d07SAart Bik LLVM::LLVMDialect::getEmitCWrapperAttrName())) { 233d00e6d07SAart Bik func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), 234d00e6d07SAart Bik UnitAttr::get(context)); 23533b463adSAart Bik funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName()); 236d00e6d07SAart Bik } 23733b463adSAart Bik return success(); 23833b463adSAart Bik } 2395122a2c2SAart Bik 2405122a2c2SAart Bik private: 2415122a2c2SAart Bik const bool directOut; 24233b463adSAart Bik }; 24333b463adSAart Bik 24433b463adSAart Bik } // namespace 24533b463adSAart Bik 24633b463adSAart Bik //===----------------------------------------------------------------------===// 24733b463adSAart Bik // Public method for populating conversion rules. 24833b463adSAart Bik //===----------------------------------------------------------------------===// 24933b463adSAart Bik 2505122a2c2SAart Bik void mlir::populateSparseAssembler(RewritePatternSet &patterns, 2515122a2c2SAart Bik bool directOut) { 2525122a2c2SAart Bik patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut); 25333b463adSAart Bik } 254