//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert type range to new types range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes = nullptr) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }

    // Convert the external representations of the pos/crd/val arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
                                               Type t, FieldIndex,
                                               SparseTensorFieldKind kind,
                                               Level, LevelType) {
      if (kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        ShapedType st = cast<ShapedType>(t);
        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
        convTypes.push_back(rtp);
        if (extraTypes)
          extraTypes->push_back(rtp);
      }
      return true;
    });
  }
}

// Convert input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra,
                     bool isIn) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble.

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level, LevelType) {
      if (kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else {
          ShapedType st = cast<ShapedType>(t);
          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}
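
// As an illustration (hypothetical example; #CSR stands for an ordinary CSR
// encoding, and the exact field shapes and bitwidths depend on the actual
// encoding), convTypes externalizes a sparse argument of type
//
//   tensor<10x20xf64, #CSR>
//
// into three plain tensor arguments
//
//   tensor<?xindex>, tensor<?xindex>, tensor<?xf64>
//
// for the positions, coordinates, and values, respectively; convVals then
// reconnects such arguments to the original sparse value through the
// sparse_tensor.assemble and sparse_tensor.disassemble operations built above.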

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// TODO: refine output sparse tensors to work well with external frameworks
//
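// As a concrete sketch of the input case (illustrative only; #CSR is any
// CSR-style annotation and the assemble operand syntax is abbreviated):
//
//   func.func @foo(%t: tensor<10x20xf64, #CSR>) { ... }
//
// becomes a private @_internal_foo with the original body, plus the public
// wrapper
//
//   func.func @foo(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                  %val: tensor<?xf64>) {
//     %t = sparse_tensor.assemble ...   // rebuild sparse tensor from pieces
//     call @_internal_foo(%t) : (tensor<10x20xf64, #CSR>) -> ()
//     return
//   }
//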
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);

    // Only sparse inputs or outputs need a wrapper method.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct the new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
  patterns.add<SparseFuncAssembler>(patterns.getContext());
}
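
// A minimal usage sketch (hypothetical driver code, not part of this file):
//
//   RewritePatternSet patterns(module.getContext());
//   populateSparseAssembler(patterns);
//   if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
//     signalPassFailure();
//
// where `module` is the enclosing ModuleOp of a pass that wires in this
// rewriting rule.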