//===- SparseAssembler.cpp - adds wrapper method around sparse types -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert type range to new types range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }

    // Convert the external representations of the pos/crd/val arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(
        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
                                                 SparseTensorFieldKind kind,
                                                 Level, LevelType) {
          if (kind == SparseTensorFieldKind::PosMemRef ||
              kind == SparseTensorFieldKind::CrdMemRef ||
              kind == SparseTensorFieldKind::ValMemRef) {
            auto rtp = cast<ShapedType>(t);
            if (!directOut) {
              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
              if (extraTypes)
                extraTypes->push_back(rtp);
            }
            convTypes.push_back(rtp);
          }
          return true;
        });
  }
}
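
// For example (a schematic sketch, assuming a 2-d CSR-style encoding #CSR
// with default index/position widths), convTypes maps the type range
//
//   (tensor<f64>, tensor<10x20xf64, #CSR>)
//
// to the externalized type range
//
//   (tensor<f64>, tensor<?xindex>, tensor<?xindex>, tensor<?xf64>)
//
// where the three trailing tensors carry the positions, coordinates, and
// values arrays of the sparse operand, respectively.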

// Convert input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level lv, LevelType) {
      if (kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else if (directOut) {
          Value mem;
          if (kind == SparseTensorFieldKind::PosMemRef)
            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
                                                               lv);
          else if (kind == SparseTensorFieldKind::CrdMemRef)
            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
                                                                 lv);
          else
            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
          toVals.push_back(mem);
        } else {
          ShapedType rtp = cast<ShapedType>(t);
          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else if (!directOut) {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
//
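// As a concrete sketch (assuming a CSR-style encoding #CSR with default
// index widths; the exact assemble syntax is elided), a public method
//
//   func.func @foo(%t : tensor<10x20xf64, #CSR>)
//
// becomes a private @_internal_foo with the same signature, plus
//
//   func.func @foo(%pos : tensor<?xindex>, %crd : tensor<?xindex>,
//                  %val : tensor<?xf64>) {
//     %t = sparse_tensor.assemble ... to tensor<10x20xf64, #CSR>
//     call @_internal_foo(%t) : (tensor<10x20xf64, #CSR>) -> ()
//     return
//   }
//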
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);

    // Only sparse inputs or outputs need a wrapper method.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential c-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  const bool directOut;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
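
// Typical usage (a minimal sketch; the actual driver is the corresponding
// sparse-assembler pass, which is assumed to wire in the direct-out option):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));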