//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with the runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
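/// For example (an illustrative sketch, not verified compiler output): for a
/// CSR tensor of type `tensor<32x?xf64, #CSR>`, level 0 folds to
/// `arith.constant 32 : index`, while level 1 lowers to a runtime query
/// roughly of the form
///   %sz = call @sparseLvlSize(%tensor, %c1) : (!llvm.ptr, index) -> index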
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it. (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
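/// For instance (an assumed example), a CSR encoding with levels
/// `(dense, compressed)` yields a two-element stack buffer holding the two
/// corresponding level-type encodings.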
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare pointers to the given
/// level buffers and value buffer, returned as a single opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Pass in the level buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Pass in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
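  /// (Used by the assertion in `genNewCall` to verify that `genBuffers`
  /// has been called first.)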
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array (AoS view).
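/// The runtime presumably returns the flat coordinate storage shared by the
/// trailing COO levels, i.e. coordinates laid out per entry
/// (array-of-structures) rather than one coordinate array per level.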
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
                                      SparseTensorType stt, Value ptr,
                                      Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<25> name{"sparseCoordinatesBuffer",
                       overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the reinterpret_map operator, which is
/// simply folded away.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
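/// Lowers `sparse_tensor.new` by creating a runtime reader for the external
/// source, materializing the tensor from that reader, and then freeing the
/// reader again.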
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                   op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses (AoS style).
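/// This retrieves the flat coordinate buffer starting at the first COO level
/// (via `genCoordinatesBufferCall`) rather than a per-level coordinate array.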
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesBufferCall(
        rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number-of-entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array to obtain the number of actually
    // stored values.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through
    // stack-allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion is not handled by this pattern.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists the alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry to the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace the expansion op with these buffers and the initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
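/// Lowers `sparse_tensor.compress` to an `expInsert` runtime call and then
/// deallocates the expansion buffers (values, filled-switch, added) once the
/// enclosing loop nest is exited.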
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit from the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying them into external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is already done here, using the out-levels
/// and out-values clause.
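/// For a trailing COO region, however, the coordinates are copied into the
/// client-provided buffer in AoS form, since the internal storage keeps them
/// in SoA form.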
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO levels. Since the target coordinate buffer used for the
        // trailing COO is passed in as an AoS scheme while SparseTensorStorage
        // uses a SoA scheme, we cannot simply expose the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle the AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //   for (i = 0; i < crdLen; i++)
      //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Convert the MemRefs back to tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Append the actual memory length used by each returned buffer.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

/// Sparse conversion rule for the has_runtime_library operator.
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
           SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
           SparseTensorInsertConverter, SparseTensorExpandConverter,
           SparseTensorCompressConverter, SparseTensorAssembleConverter,
           SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
          typeConverter, patterns.getContext());
}