//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with the runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
  return std::nullopt;
}

/// Replaces `op` with a `func::CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
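// For illustration, the runtime query emitted by `genLvlSizeCall` for
// level 1 is expected to look roughly as follows (a sketch only; the
// pointer type depends on the opaque-pointer lowering in use):
//
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseLvlSize(%tensor, %c1) : (!llvm.ptr<i8>, index) -> index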
/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr is an `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}
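// As an illustrative sketch of `createOrFoldDimCall`: for a value of type
// `tensor<8x?xf64, #CSR>` (with `#CSR` standing for any sparse encoding),
// dimension 0 folds to `arith.constant 8 : index`, while dimension 1
// lowers to a runtime query along the lines of:
//
//   %c1 = arith.constant 1 : index
//   %d1 = call @sparseDimSize(%tensor, %c1) : (!llvm.ptr<i8>, index) -> index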
/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer that holds the bare pointers of the given
/// level buffers and value buffer, returned as a single opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    setTemplateTypes(stt);
    // Finally, make note that initialization is complete.
    assert(isInitialized() && "Initialization failed");
    // And return `this` for method chaining.
    return *this;
  }
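  // A typical use of this class chains the two steps together, as done by
  // the conversion rules further below, e.g.:
  //
  //   Value tensor = NewCallParams(rewriter, loc)
  //                      .genBuffers(stt, dimSizes)
  //                      .genNewCall(Action::kEmpty);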
  /// (Re)sets the C++ template type parameters, and returns `this`
  /// for method chaining. This is already done as part of `genBuffers`,
  /// but is factored out so that it can also be called independently
  /// whenever subsequent `genNewCall` calls want to reuse the same
  /// buffers but different type parameters.
  //
  // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
  // is there a better way to handle that than this one-off setter method?
  NewCallParams &setTemplateTypes(SparseTensorType stt) {
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}
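// As an illustrative note on the naming scheme above: runtime entry points
// are resolved by suffixing the base name with the primary type, so for
// `f64` elements `genValuesCall` is expected to target a function named
// `sparseValuesF64` and `genDelCOOCall` one named `delSparseTensorCOOF64`
// (exact suffix spelling per `primaryTypeFunctionSuffix`).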
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing dimension-sizes.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse DimOp.
    if (!stt.hasEncoding())
      return failure();
    // Only rewrite DimOp with constant index.
    std::optional<int64_t> dim = op.getConstantIndex();
    if (!dim)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(
        op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};
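// For example (a sketch, with `#CSR` standing for any sparse encoding),
// a trivial cast such as
//
//   %t = tensor.cast %s : tensor<8x?xf64, #CSR> to tensor<?x?xf64, #CSR>
//
// is handled by simply forwarding the converted operand, since source and
// destination lower to the same opaque pointer.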
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the reader opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    Value lvlSizesBuffer =
        genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
                      dim2lvlBuffer, lvl2dimBuffer);
    // Use the `reader` to parse the file.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Type eltTp = stt.getElementType();
    Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
    SmallVector<Value, 8> params{
        reader,
        lvlSizesBuffer,
        genLvlTypesBuffer(rewriter, loc, stt),
        dim2lvlBuffer,
        lvl2dimBuffer,
        constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
        constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
        valTp};
    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
                                  opaqueTp, params, EmitCInterface::On)
                       .getResult(0);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct the empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter
    : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct the empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
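// Sketch of the overall lowering performed by the two rules above (names
// illustrative; `#CSR` stands for any sparse encoding): an annotated
//
//   %t = bufferization.alloc_tensor(%d) : tensor<8x?xf64, #CSR>
//
// becomes, roughly, a series of stack buffers describing the type,
// followed by one "swiss army knife" runtime call whose parameters follow
// the `kParam*` ordering of `NewCallParams`:
//
//   %ptr = call @newSparseTensor(%dimSizes, %lvlSizes, %lvlTypes,
//                                %dim2lvl, %lvl2dim, %posTp, %crdTp,
//                                %valTp, %actionEmpty, %nullPtr)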
/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};
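// As an illustrative sketch, a position access such as
//
//   %p = sparse_tensor.positions %t { level = 1 : index }
//          : tensor<?x?xf64, #CSR> to memref<?xindex>
//
// lowers to a call along the lines of the following, where the suffix is
// chosen by `overheadTypeFunctionSuffix` from the position bit-width (`0`
// assumed here for the index-width default):
//
//   %p = call @sparsePositions0(%ptr, %c1)
//          : (!llvm.ptr<i8>, index) -> memref<?xindex>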
/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array to obtain the number of actually
    // stored values.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    auto vref = genAllocaScalar(rewriter, loc, elemTp);
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
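// For illustration (a sketch; exact memref types are approximate),
// inserting an f64 value into a 2-level tensor lowers roughly to:
//
//   memref.store %coord0, %lvlCoords[%c0] : memref<2xindex>
//   memref.store %coord1, %lvlCoords[%c1] : memref<2xindex>
//   memref.store %val, %vref[] : memref<f64>
//   call @lexInsertF64(%tensor, %lvlCoords, %vref)
//     : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> ()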
/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};
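// For illustration (a sketch, assuming f64 elements and innermost level
// size %n), the expand op is replaced by three heap buffers plus a zero
// initial count:
//
//   %values = memref.alloc(%n) : memref<?xf64>
//   %filled = memref.alloc(%n) : memref<?xi1>
//   %added  = memref.alloc(%n) : memref<?xindex>
//   linalg.fill ins(%f0 : f64) outs(%values : memref<?xf64>)
//   linalg.fill ins(%false : i1) outs(%filled : memref<?xi1>)
//   %count  = arith.constant 0 : index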
/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    Value coo = NewCallParams(rewriter, loc)
                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
                    .genNewCall(Action::kToCOO, src);
    // Then output the tensor to an external file with coordinates in the
    // externally visible lexicographic coordinate order. A sort is
    // required if the source was not in that order yet (note that the
    // sort can be dropped altogether if the external format does not care
    // about the order at all, but here we assume it does).
    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
    const Type elemTp = srcTp.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOps always return a statically-shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}
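// Typical usage of the type converter together with the patterns below,
// inside a conversion pass (an illustrative sketch only, not the actual
// pass driver; see the sparse tensor conversion pass for the real setup):
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeToPtrConverter converter;
//   ConversionTarget target(*ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();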
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
           SparseCastConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorOutConverter, SparseTensorAssembleConverter>(
          typeConverter, patterns.getContext());
}