//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls to a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
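// For reference, each helper above lowers to a plain runtime call; e.g.,
// querying level 0 yields IR roughly like the following sketch (the sparse
// tensor operand has already been converted to an opaque pointer):
//
//   %c0 = arith.constant 0 : index
//   %sz = call @sparseLvlSize(%tensor, %c0) : (!llvm.ptr, index) -> index
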
/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr is an `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to
/// have a static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare pointers to the level
/// buffers and the values buffer of the given tensors.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

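  // A typical use, as in the conversion patterns below, chains the two steps
  // together; a minimal sketch (the exact arguments depend on the op being
  // rewritten):
  //
  //   Value tensor = NewCallParams(rewriter, loc)
  //                      .genBuffers(stt, dimSizesValues)
  //                      .genNewCall(Action::kEmpty);
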
  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

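// For reference, the three accessor helpers above return rank-1 memrefs that
// alias the runtime storage; e.g., for an f64 element type, genValuesCall
// emits IR roughly like the following sketch (the actual symbol goes through
// the _mlir_ciface_ wrapper and its suffix depends on the element type):
//
//   %v = call @sparseValuesF64(%tensor) : (!llvm.ptr) -> memref<?xf64>
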
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
                                   adaptor.getTensor(), op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array to get the number of stored entries.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We simply expose the buffers to the external client. This
    // assumes the client only reads the buffers (usually copying them
    // to external data structures, such as numpy arrays).
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the values buffer first.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
    // Then get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO region. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //   for (i = 0; i < crdLen; i++)
      //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf =
          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }
    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
          typeConverter, patterns.getContext());
}