1 //===- SparseAssembler.cpp - adds wrapper method around sparse types ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Utils/CodegenUtils.h" 10 11 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 12 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 14 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "llvm/Support/FormatVariadic.h" 17 18 using namespace mlir; 19 using namespace sparse_tensor; 20 21 //===----------------------------------------------------------------------===// 22 // Helper methods. 23 //===----------------------------------------------------------------------===// 24 25 // Convert type range to new types range, with sparse tensors externalized. 26 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes, 27 SmallVectorImpl<Type> *extraTypes = nullptr) { 28 for (auto type : types) { 29 // All "dense" data passes through unmodified. 30 if (!getSparseTensorEncoding(type)) { 31 convTypes.push_back(type); 32 continue; 33 } 34 // Convert the external representation of the values array. 35 const SparseTensorType stt(cast<RankedTensorType>(type)); 36 auto shape = stt.getBatchLvlShape(); 37 shape.push_back(ShapedType::kDynamic); 38 auto vtp = RankedTensorType::get(shape, stt.getElementType()); 39 convTypes.push_back(vtp); 40 if (extraTypes) 41 extraTypes->push_back(vtp); 42 43 // Convert the external representation of the position/coordinate array. 
44 foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes]( 45 Type t, FieldIndex, 46 SparseTensorFieldKind kind, 47 Level, LevelType) { 48 if (kind == SparseTensorFieldKind::CrdMemRef || 49 kind == SparseTensorFieldKind::PosMemRef) { 50 ShapedType st = t.cast<ShapedType>(); 51 auto rtp = RankedTensorType::get(st.getShape(), st.getElementType()); 52 convTypes.push_back(rtp); 53 if (extraTypes) 54 extraTypes->push_back(rtp); 55 } 56 return true; 57 }); 58 } 59 } 60 61 // Convert input and output values to [dis]assemble ops for sparse tensors. 62 static void convVals(OpBuilder &builder, Location loc, TypeRange types, 63 ValueRange fromVals, ValueRange extraVals, 64 SmallVectorImpl<Value> &toVals, unsigned extra, 65 bool isIn) { 66 unsigned idx = 0; 67 for (auto type : types) { 68 // All "dense" data passes through unmodified. 69 if (!getSparseTensorEncoding(type)) { 70 toVals.push_back(fromVals[idx++]); 71 continue; 72 } 73 // Convert the external representation of the values array. 74 auto rtp = cast<RankedTensorType>(type); 75 const SparseTensorType stt(rtp); 76 auto shape = stt.getBatchLvlShape(); 77 shape.push_back(ShapedType::kDynamic); 78 SmallVector<Value> inputs; 79 SmallVector<Type> retTypes; 80 SmallVector<Type> cntTypes; 81 // Collect the external representation of the values array for 82 // input or the outgoing sparse tensor for output. 83 inputs.push_back(fromVals[idx++]); 84 if (!isIn) { 85 inputs.push_back(extraVals[extra++]); 86 retTypes.push_back(RankedTensorType::get(shape, stt.getElementType())); 87 cntTypes.push_back(builder.getIndexType()); // nnz 88 } 89 90 // Collect the external representations of the pos/crd arrays. 
91 foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex, 92 SparseTensorFieldKind kind, 93 Level, LevelType) { 94 if (kind == SparseTensorFieldKind::CrdMemRef || 95 kind == SparseTensorFieldKind::PosMemRef) { 96 if (isIn) { 97 inputs.push_back(fromVals[idx++]); 98 } else { 99 ShapedType st = t.cast<ShapedType>(); 100 auto rtp = RankedTensorType::get(st.getShape(), st.getElementType()); 101 inputs.push_back(extraVals[extra++]); 102 retTypes.push_back(rtp); 103 cntTypes.push_back(rtp.getElementType()); 104 } 105 } 106 return true; 107 }); 108 109 if (isIn) { 110 // Assemble multiple inputs into a single sparse tensor. 111 auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs); 112 toVals.push_back(a.getResult()); 113 } else { 114 // Disassemble a single sparse input into multiple outputs. 115 // Note that this includes the counters, which are dropped. 116 unsigned len = retTypes.size(); 117 retTypes.append(cntTypes); 118 auto d = 119 builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs); 120 for (unsigned i = 0; i < len; i++) 121 toVals.push_back(d.getResult(i)); 122 } 123 } 124 } 125 126 //===----------------------------------------------------------------------===// 127 // Rewriting rules. 128 //===----------------------------------------------------------------------===// 129 130 namespace { 131 132 // A rewriting rules that converts public entry methods that use sparse tensors 133 // as input parameters and/or output return values into wrapper methods that 134 // [dis]assemble the individual tensors that constitute the actual storage used 135 // externally into MLIR sparse tensors before calling the original method. 136 // 137 // In particular, each sparse tensor input 138 // 139 // void foo(..., t, ...) { } 140 // 141 // makes the original foo() internal and adds the following wrapper method 142 // 143 // void foo(..., t1..tn, ...) { 144 // t = assemble t1..tn 145 // _internal_foo(..., t, ...) 
//   }
//
// and likewise, each output tensor
//
//   ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
//   ... T1..TN ... bar(..., t1'..tn') {
//     ..., t, ... = _internal_bar(...)
//     t1..tn = disassemble t, t1'..tn'
//     return ..., t1..tn, ...
//   }
//
// TODO: refine output sparse tensors to work well with external framework
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types. Sparse argument
    // types expand in place in inputTypes/outputTypes; for sparse results,
    // the expanded types are also collected in extraTypes so the wrapper
    // can take caller-provided destination buffers as trailing arguments.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);

    // Only sparse inputs or outputs need a wrapper method.
    // If no type expanded, the sizes are unchanged and there is nothing
    // to do.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with original name.
    // Note: `extra` records where the destination-buffer arguments begin
    // in the wrapper signature, before extraTypes is appended.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs: assemble external buffers into sparse tensors.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return: disassemble sparse results into the
    // destination buffers passed as the trailing wrapper arguments.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential c-interface property, so that the
    // emit-C-wrapper attribute now applies to the public wrapper rather
    // than the internal method.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
237 //===----------------------------------------------------------------------===// 238 239 void mlir::populateSparseAssembler(RewritePatternSet &patterns) { 240 patterns.add<SparseFuncAssembler>(patterns.getContext()); 241 } 242