//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
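/// For example, given a `tensor<8x?xf64, #CSR>` (identity dim-to-lvl map),
/// querying level 0 folds to the constant 8, whereas querying level 1
/// lowers to a `sparseLvlSize` runtime call.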
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here: because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the *tensor* parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
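/// For a CSR-style encoding with lvlTypes = [dense, compressed], for
/// instance, this allocates a two-element stack buffer holding the
/// corresponding level-type codes (the exact numeric encoding is an
/// implementation detail of Enums.h, not repeated here).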
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer with the bare pointers to the given level
/// buffers and values buffer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite LvlOp on sparse tensors.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the reinterpret_map operator.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
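///
/// An illustrative (not verbatim) lowering of a sparse-annotated
///   %t = bufferization.alloc_tensor(%d) : tensor<8x?xf64, #CSR>
/// first materializes the dim-size, lvl-size, lvl-type, and dim<->lvl map
/// buffers via `NewCallParams::genBuffers`, and then obtains the opaque
/// storage pointer through the "swiss army knife" entry point, roughly:
///   %ptr = call @newSparseTensor(%dimSizes, %lvlSizes, %lvlTypes,
///                                %dim2lvl, %lvl2dim, %posTp, %crdTp,
///                                %valTp, %action, %nullPtr)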
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter
    : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
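///
/// For instance (illustrative only), a sparse-annotated
///   bufferization.dealloc_tensor %t : tensor<?x?xf64, #CSR>
/// lowers to a plain runtime call that releases the underlying storage:
///   call @delSparseTensor(%t) : (!llvm.ptr) -> ()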
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
                                   adaptor.getTensor(), op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array to obtain the number of actually
    // stored entries.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
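///
/// For example (an illustrative, not verbatim, lowering), a finalizing
///   %r = sparse_tensor.load %t hasInserts : tensor<?x?xf64, #CSR>
/// becomes a runtime call followed by forwarding the converted operand:
///   call @endLexInsert(%t) : (!llvm.ptr) -> ()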
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
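    // For example, expanding along the innermost level of a 1000x1000 CSR
    // matrix requires buffers of size 1000 (one slot per column), which is
    // why heap allocation is preferred over alloca here.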
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
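    // (The single opaque pointer passed to the Action::kPack call below
    // aggregates the bare aligned pointers of all level buffers plus the
    // values buffer, as assembled by genLvlPtrsBuffers above.)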
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We simply expose the buffers to the external client. This
    // assumes the client only reads the buffers (usually copying them
    // to external data structures, such as numpy arrays).
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the values buffer first.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
    // Then get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO region. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //   for (i = 0; i < crdLen; i++)
      //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf =
          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }
    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
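///
/// A minimal usage sketch (assuming the surrounding pass has already set up
/// `ctx`, `op`, and a suitable `ConversionTarget target`; not verbatim from
/// any particular pass):
///
///   SparseTensorTypeToPtrConverter converter;
///   RewritePatternSet patterns(ctx);
///   populateSparseTensorConversionPatterns(converter, patterns);
///   if (failed(applyPartialConversion(op, target, std::move(patterns))))
///     signalPassFailure();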
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
          typeConverter, patterns.getContext());
}