//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls to a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Replaces `op` with a `CallOp` to the function reference returned by
/// `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
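// For illustration only (SSA names below are hypothetical), the call emitted
// by `genLvlSizeCall` for level 1 of an already-converted sparse tensor
// `%t : !llvm.ptr` looks roughly like:
//
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseLvlSize(%t, %c1) : (!llvm.ptr, index) -> index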

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because `getDimPosition`
  // already checks that the expression is an `AffineDimExpr`, which is all
  // we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}
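// A rough sketch of the size resolution above (the encoding name is
// illustrative): for a sparse tensor of type `tensor<?x8xf64, #CSR>`,
// querying dimension 1 folds to `arith.constant 8 : index`, while the dynamic
// dimension 0 becomes a `call @sparseDimSize`; for an unannotated (dense)
// tensor, the dynamic case is resolved with a dim op instead of a runtime
// call.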

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer to the buffer that backs the tensor.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer that holds the bare pointers to the given
/// level buffers and the values buffer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}
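// Sketch of the buffer built by `genLvlPtrsBuffers` for an assemble op with
// two level buffers (all names are illustrative): an `alloca` holding
//
//   [ alignedPtr(%lvlBuf0), alignedPtr(%lvlBuf1), alignedPtr(%values) ]
//
// whose own aligned pointer is then cast index -> i64 -> !llvm.ptr so it can
// be passed to the runtime as a single opaque argument.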

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
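// Typical usage of `NewCallParams` (mirroring the conversion rules below;
// `stt` and `dimSizes` are assumed to already be in scope):
//
//   Value tensor = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizes)
//                      .genNewCall(Action::kEmpty);
//
// The static parameters are set up once; the action and the optional pointer
// argument are supplied when the actual call is emitted.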

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite LvlOp on sparse tensors.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();
    if (!lvl)
      return failure();

    // At this point, if the level size were constant, the operation would
    // already have been folded by LvlOp's folder, so we generate the call
    // unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the reinterpret_map operator.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
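// For illustration only (SSA names are hypothetical, and the reader-creation
// calls generated by `genReader` are elided), a `sparse_tensor.new` roughly
// lowers to:
//
//   %reader = ...                                  // emitted by genReader
//   %t = call @newSparseTensor(...)                // Action::kFromReader
//   call @delSparseTensorReader(%reader) : (!llvm.ptr) -> ()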

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter
    : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};
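// For illustration only (the encoding name is hypothetical), the dealloc rule
// above turns
//
//   bufferization.dealloc_tensor %t : tensor<?x?xf64, #CSR>
//
// into a single runtime call on the converted opaque pointer:
//
//   call @delSparseTensor(%ptr) : (!llvm.ptr) -> ()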

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef to the type expected by the users, even though the
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array to obtain the number of actually
    // stored values.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};
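// For illustration only (SSA names and the encoding are hypothetical), the
// value-access rule above rewrites
//
//   %v = sparse_tensor.values %t : tensor<?x?xf64, #CSR> to memref<?xf64>
//
// into
//
//   %v = call @sparseValuesF64(%ptr) : (!llvm.ptr) -> memref<?xf64>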

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
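// For illustration only (names hypothetical), a lexicographic insertion of an
// f64 value is emitted roughly as: a level-coordinate buffer and a scalar
// buffer are alloca'ed (hoisted out of the enclosing loop nest), the
// coordinates and the value are stored into them, and
//
//   call @lexInsertF64(%t, %lvlCoords, %vref)
//
// performs the actual insertion through the runtime.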

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};
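// For illustration only: the expand/compress rules above cooperate through
// the runtime. The expand rule heap-allocates the `values`, `filled`, and
// `added` buffers (their size equals the innermost level size, so stack
// buffers are avoided), and the compress rule hands them to a call like
//
//   call @expInsertF64(%t, %lvlCoords, %values, %filled, %added, %count)
//
// which inserts the gathered entries and resets only the touched positions.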

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically-shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
}
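// A minimal sketch of how a conversion driver typically wires up these
// patterns (illustrative only; the actual pass definition lives elsewhere):
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeToPtrConverter converter;
//   ConversionTarget target(*ctx);
//   // ... configure target legality for the involved dialects ...
//   populateSparseTensorConversionPatterns(converter, patterns);
//   (void)applyPartialConversion(module, target, std::move(patterns));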