//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// TODO: reuse StorageLayout::foreachField?

// TODO: we need COO AoS and SoA

// Converts a type range to a new type range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes = nullptr) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }
    // Convert the external representation of the values array.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    auto shape = {ShapedType::kDynamic};
    auto vtp = RankedTensorType::get(shape, stt.getElementType());
    convTypes.push_back(vtp);
    if (extraTypes)
      extraTypes->push_back(vtp);
    // Convert the external representations of the pos/crd arrays.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        auto ptp = RankedTensorType::get(shape, stt.getPosType());
        auto ctp = RankedTensorType::get(shape, stt.getCrdType());
        convTypes.push_back(ptp);
        convTypes.push_back(ctp);
        if (extraTypes) {
          extraTypes->push_back(ptp);
          extraTypes->push_back(ctp);
        }
      } else {
        assert(isDenseLT(lt)); // TODO: handle other cases
      }
    }
  }
}
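
// For example (an illustrative sketch, assuming a 2-d CSR-style encoding
// with a dense first level and a compressed second level, and the default
// index overhead types), an argument or result of type
//
//   tensor<?x?xf64, #CSR>
//
// externalizes into three "dense" tensors, in this order:
//
//   tensor<?xf64>    (values)
//   tensor<?xindex>  (positions of the compressed level)
//   tensor<?xindex>  (coordinates of the compressed level)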

// Converts input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra,
                     bool isIn) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Convert the external representation of the values array.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    auto shape = {ShapedType::kDynamic};
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    // Collect the external representation of the values array for
    // input, or the outgoing sparse tensor for output.
    inputs.push_back(fromVals[idx++]);
    if (!isIn) {
      inputs.push_back(extraVals[extra++]);
      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
      cntTypes.push_back(builder.getIndexType());
    }
    // Collect the external representations of the pos/crd arrays.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
          inputs.push_back(fromVals[idx++]);
        } else {
          Type pTp = stt.getPosType();
          Type cTp = stt.getCrdType();
          inputs.push_back(extraVals[extra++]);
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(RankedTensorType::get(shape, pTp));
          retTypes.push_back(RankedTensorType::get(shape, cTp));
          cntTypes.push_back(pTp);
          cntTypes.push_back(cTp);
        }
      } else {
        assert(isDenseLT(lt)); // TODO: handle other cases
      }
    }
    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}
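
// On the output path, for the same CSR-style example (a hedged sketch),
// the disassemble op consumes the sparse result together with the
// caller-allocated buffers passed as extra arguments, and yields
//
//   values, pos, crd, n1, n2, n3
//
// where n1..n3 count the entries actually written into each buffer;
// only the first three results are forwarded, and the counters are dropped.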

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// TODO: refine output sparse tensors to work well with external framework
//
// TODO: use "inlining" instead of a wrapper?
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);

    // Only sparse inputs or outputs need a wrapper method.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct the new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, /*isIn=*/true);

    // Call the original, now internal method.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(),
                                              org, inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
  patterns.add<SparseFuncAssembler>(patterns.getContext());
}
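
// A minimal usage sketch (a hypothetical pass body, assuming the standard
// greedy pattern driver):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparseAssembler(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));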