//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert a type range to a new type range, with sparse tensors externalized.
static void convTypes(bool &hasAnnotation, TypeRange types,
                      SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }
    hasAnnotation = true;

    // Convert the external representations of the pos/crd/val arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(
        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
                                                 SparseTensorFieldKind kind,
                                                 Level, LevelType) {
          if (kind == SparseTensorFieldKind::PosMemRef ||
              kind == SparseTensorFieldKind::CrdMemRef ||
              kind == SparseTensorFieldKind::ValMemRef) {
            auto rtp = cast<ShapedType>(t);
            if (!directOut) {
              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
              if (extraTypes)
                extraTypes->push_back(rtp);
            }
            convTypes.push_back(rtp);
          }
          return true;
        });
  }
}
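// For example (an illustrative sketch, assuming a 2-d CSR encoding with the
// default index and f64 types), a sparse argument of type
//
//   tensor<?x?xf64, #CSR>
//
// is externalized by convTypes() into the three ranked tensor types
//
//   tensor<?xindex>   (positions)
//   tensor<?xindex>   (coordinates)
//   tensor<?xf64>     (values)
//
// which take the place of the sparse tensor in the wrapper signature.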
// Convert input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble.

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level lv, LevelType) {
      if (kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else if (directOut) {
          Value mem;
          if (kind == SparseTensorFieldKind::PosMemRef)
            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
                                                               lv);
          else if (kind == SparseTensorFieldKind::CrdMemRef)
            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
                                                                 lv);
          else
            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
          toVals.push_back(mem);
        } else {
          ShapedType rtp = cast<ShapedType>(t);
          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else if (!directOut) {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
//
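// As a concrete MLIR sketch of the input case (illustrative only; #CSR is an
// assumed CSR encoding, and the exact assemble/call syntax is abbreviated):
//
//   func.func @foo(%t: tensor<?x?xf64, #CSR>) { ... }
//
// becomes
//
//   func.func private @_internal_foo(%t: tensor<?x?xf64, #CSR>) { ... }
//   func.func @foo(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                  %val: tensor<?xf64>) {
//     %t = sparse_tensor.assemble (%pos, %crd), %val ...
//     call @_internal_foo(%t) ...
//   }
//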
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    bool hasAnnotation = false;
    convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
              false);
    convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
              directOut);

    // Only sparse inputs or outputs need a wrapper method.
    if (!hasAnnotation)
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  const bool directOut;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
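// Usage note (an illustrative sketch, not taken from this file): a pass
// driver would typically apply these patterns greedily over a module, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));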