//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
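// For example (schematically, eliding the exact types and calling
// convention), an operation such as `sparse_tensor.values %t` on a sparse
// tensor with f64 elements is rewritten into a library call along the
// lines of
//
//   %v = call @sparseValuesF64(%p) : (!llvm.ptr) -> memref<?xf64>
//
// where `%p` is the opaque pointer into which the sparse tensor value `%t`
// has been type-converted.
//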
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the *tensor* parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor data.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare pointers of the given
/// level tensors and value tensor, returned as a single opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
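///
/// A minimal usage sketch (mirroring the conversion patterns below):
///
///   Value tensor = NewCallParams(rewriter, loc)
///                      .genBuffers(stt, dimSizes)
///                      .genNewCall(Action::kEmpty);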
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
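  ///
  /// Schematically (following the parameter order defined by the `kParam*`
  /// constants below), the generated IR looks along the lines of:
  ///
  ///   %t = call @newSparseTensor(%dimSizes, %lvlSizes, %lvlTypes,
  ///                              %dim2lvl, %lvl2dim, %posTp, %crdTp,
  ///                              %valTp, %action, %ptr)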
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By this point, if the level size were constant, the operation should
    // already have been folded by LvlOp's folder, so we generate the call
    // unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

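/// Sparse conversion rule for the reinterpret_map operator.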
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array to obtain the number of actually
    // stored values.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through
    // stack-allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
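    // The runtime entry point is suffixed by the element type (e.g.,
    // presumably `lexInsertF64` for f64 values, assuming the usual suffix
    // scheme).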
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically-shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
}