1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A pass that converts sparse tensor primitives into calls into a runtime 10 // support library. Sparse tensor types are converted into opaque pointers 11 // to the underlying sparse storage schemes. The use of opaque pointers 12 // together with runtime support library keeps the conversion relatively 13 // simple, but at the expense of IR opacity, which obscures opportunities 14 // for subsequent optimization of the IR. An alternative is provided by 15 // the SparseTensorCodegen pass. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "CodegenUtils.h" 20 21 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 22 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 23 #include "mlir/Dialect/Linalg/Utils/Utils.h" 24 #include "mlir/Dialect/MemRef/IR/MemRef.h" 25 #include "mlir/Dialect/SCF/IR/SCF.h" 26 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 27 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 28 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 29 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 30 #include "mlir/Dialect/Tensor/IR/Tensor.h" 31 #include "mlir/Transforms/DialectConversion.h" 32 33 using namespace mlir; 34 using namespace mlir::sparse_tensor; 35 36 namespace { 37 38 //===----------------------------------------------------------------------===// 39 // Helper methods. 40 //===----------------------------------------------------------------------===// 41 42 /// Maps each sparse tensor type to an opaque pointer. 
43 static std::optional<Type> convertSparseTensorTypes(Type type) { 44 if (getSparseTensorEncoding(type) != nullptr) 45 return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); 46 return std::nullopt; 47 } 48 49 /// Replaces the `op` with a `CallOp` to the `getFunc()` function reference. 50 static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op, 51 StringRef name, TypeRange resultType, 52 ValueRange operands, 53 EmitCInterface emitCInterface) { 54 auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands, 55 emitCInterface); 56 return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn, 57 operands); 58 } 59 60 /// Generates call to lookup a level-size. N.B., this only generates 61 /// the raw function call, and therefore (intentionally) does not perform 62 /// any dim<->lvl conversion or other logic. 63 static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor, 64 uint64_t lvl) { 65 StringRef name = "sparseLvlSize"; 66 SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)}; 67 Type iTp = builder.getIndexType(); 68 return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off) 69 .getResult(0); 70 } 71 72 /// Generates call to lookup a dimension-size. N.B., this only generates 73 /// the raw function call, and therefore (intentionally) does not perform 74 /// any dim<->lvl conversion or other logic. 75 static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor, 76 uint64_t dim) { 77 StringRef name = "sparseDimSize"; 78 SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)}; 79 Type iTp = builder.getIndexType(); 80 return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off) 81 .getResult(0); 82 } 83 84 /// Looks up a level-size by returning a statically-computed constant 85 /// (when possible), or by calling `genLvlSizeCall` (when dynamic). 
86 static Value createOrFoldLvlCall(OpBuilder &builder, Location loc, 87 SparseTensorType stt, Value tensor, 88 Level lvl) { 89 // Only sparse tensors have "levels" to query. 90 assert(stt.hasEncoding()); 91 // TODO: The following implementation only handles permutations; 92 // we'll need to generalize this to handle arbitrary AffineExpr. 93 // 94 // There's no need to assert `isPermutation` here: because 95 // `getDimPosition` checks that the expr isa `AffineDimExpr`, 96 // which is all we care about (for supporting permutations). 97 const Dimension dim = 98 stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl); 99 if (const auto sz = stt.getStaticDimSize(dim)) 100 return constantIndex(builder, loc, *sz); 101 // If we cannot statically compute the size from the shape, then we 102 // must dynamically query it. (In principle we could also dynamically 103 // compute it, but since we already did so to construct the `tensor` 104 // in the first place, we might as well query rather than recompute.) 105 return genLvlSizeCall(builder, loc, tensor, lvl); 106 } 107 108 /// Looks up a dimension-size by returning a constant from the shape 109 /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes 110 /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes 111 /// of dense tensors). 112 static Value createOrFoldDimCall(OpBuilder &builder, Location loc, 113 SparseTensorType stt, Value tensor, 114 Dimension dim) { 115 if (const auto sz = stt.getStaticDimSize(dim)) 116 return constantIndex(builder, loc, *sz); 117 if (stt.hasEncoding()) 118 return genDimSizeCall(builder, loc, tensor, dim); 119 return linalg::createOrFoldDimOp(builder, loc, tensor, dim); 120 } 121 122 /// Populates the array with the dimension-sizes of the given tensor. 
123 static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, 124 Value tensor, SmallVectorImpl<Value> &out) { 125 const Dimension dimRank = stt.getDimRank(); 126 out.clear(); 127 out.reserve(dimRank); 128 for (Dimension d = 0; d < dimRank; d++) 129 out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d)); 130 } 131 132 /// Returns an array with the dimension-sizes of the given tensor. 133 /// If the *tensor* parameters is null, the tensor type is assumed to have a 134 /// static shape. 135 static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc, 136 SparseTensorType stt, 137 Value tensor = Value()) { 138 SmallVector<Value> out; 139 fillDimSizes(builder, loc, stt, tensor, out); 140 return out; 141 } 142 143 /// Generates an uninitialized buffer of the given size and type, 144 /// but returns it as type `memref<? x $tp>` (rather than as type 145 /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, 146 /// this buffer must be explicitly deallocated by client. 147 static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { 148 auto memTp = MemRefType::get({ShapedType::kDynamic}, tp); 149 return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz}); 150 } 151 152 /// Generates a temporary buffer for the level-types of the given encoding. 153 static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, 154 SparseTensorType stt) { 155 SmallVector<Value> lvlTypes; 156 lvlTypes.reserve(stt.getLvlRank()); 157 for (const auto dlt : stt.getEncoding().getLvlTypes()) 158 lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt)); 159 return allocaBuffer(builder, loc, lvlTypes); 160 } 161 162 /// Extracts the bare (aligned) pointers that point to the tensor. 
163 static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc, 164 Value tensor) { 165 auto buf = genToMemref(builder, loc, tensor); 166 return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf); 167 } 168 169 /// Generates a temporary buffer for the level-types of the given encoding. 170 static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc, 171 ValueRange lvlTensors, Value valTensor) { 172 SmallVector<Value> lvlBarePtrs; 173 lvlBarePtrs.reserve(lvlTensors.size() + 1); 174 // Passing in lvl buffer pointers. 175 for (const auto lvl : lvlTensors) 176 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl)); 177 178 // Passing in value buffer pointers. 179 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor)); 180 Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>( 181 loc, allocaBuffer(builder, loc, lvlBarePtrs)); 182 Value idxCast = 183 builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr); 184 return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder), 185 idxCast); 186 } 187 188 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`: 189 /// the "swiss army knife" method of the sparse runtime support library 190 /// for materializing sparse tensors into the computation. This abstraction 191 /// reduces the need for modifications when the API changes. 192 class NewCallParams final { 193 public: 194 /// Allocates the `ValueRange` for the `func::CallOp` parameters. 195 NewCallParams(OpBuilder &builder, Location loc) 196 : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {} 197 198 /// Initializes all static parameters (i.e., those which indicate 199 /// type-level information such as the encoding and sizes), generating 200 /// MLIR buffers as needed, and returning `this` for method chaining. 
201 NewCallParams &genBuffers(SparseTensorType stt, 202 ArrayRef<Value> dimSizesValues) { 203 assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank())); 204 // Sparsity annotations. 205 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt); 206 // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers. 207 params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues); 208 params[kParamLvlSizes] = 209 genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes], 210 params[kParamDim2Lvl], params[kParamLvl2Dim]); 211 // Secondary and primary types encoding. 212 setTemplateTypes(stt); 213 // Finally, make note that initialization is complete. 214 assert(isInitialized() && "Initialization failed"); 215 // And return `this` for method chaining. 216 return *this; 217 } 218 219 /// (Re)sets the C++ template type parameters, and returns `this` 220 /// for method chaining. This is already done as part of `genBuffers`, 221 /// but is factored out so that it can also be called independently 222 /// whenever subsequent `genNewCall` calls want to reuse the same 223 /// buffers but different type parameters. 224 // 225 // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`; 226 // is there a better way to handle that than this one-off setter method? 227 NewCallParams &setTemplateTypes(SparseTensorType stt) { 228 const auto enc = stt.getEncoding(); 229 params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc); 230 params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc); 231 params[kParamValTp] = 232 constantPrimaryTypeEncoding(builder, loc, stt.getElementType()); 233 return *this; 234 } 235 236 /// Checks whether all the static parameters have been initialized. 237 bool isInitialized() const { 238 for (unsigned i = 0; i < kNumStaticParams; ++i) 239 if (!params[i]) 240 return false; 241 return true; 242 } 243 244 /// Gets the dimension-to-level mapping. 
245 // 246 // TODO: This is only ever used for passing into `genAddEltCall`; 247 // is there a better way to encapsulate that pattern (both to avoid 248 // this one-off getter, and to avoid potential mixups)? 249 Value getDimToLvl() const { 250 assert(isInitialized() && "Must initialize before getDimToLvl"); 251 return params[kParamDim2Lvl]; 252 } 253 254 /// Generates a function call, with the current static parameters 255 /// and the given dynamic arguments. 256 Value genNewCall(Action action, Value ptr = Value()) { 257 assert(isInitialized() && "Must initialize before genNewCall"); 258 StringRef name = "newSparseTensor"; 259 params[kParamAction] = constantAction(builder, loc, action); 260 params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp); 261 return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On) 262 .getResult(0); 263 } 264 265 private: 266 static constexpr unsigned kNumStaticParams = 8; 267 static constexpr unsigned kNumDynamicParams = 2; 268 static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams; 269 static constexpr unsigned kParamDimSizes = 0; 270 static constexpr unsigned kParamLvlSizes = 1; 271 static constexpr unsigned kParamLvlTypes = 2; 272 static constexpr unsigned kParamDim2Lvl = 3; 273 static constexpr unsigned kParamLvl2Dim = 4; 274 static constexpr unsigned kParamPosTp = 5; 275 static constexpr unsigned kParamCrdTp = 6; 276 static constexpr unsigned kParamValTp = 7; 277 static constexpr unsigned kParamAction = 8; 278 static constexpr unsigned kParamPtr = 9; 279 280 OpBuilder &builder; 281 Location loc; 282 Type pTp; 283 Value params[kNumParams]; 284 }; 285 286 /// Generates a call to obtain the values array. 
287 static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp, 288 ValueRange ptr) { 289 SmallString<15> name{"sparseValues", 290 primaryTypeFunctionSuffix(tp.getElementType())}; 291 return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On) 292 .getResult(0); 293 } 294 295 /// Generates a call to release/delete a `SparseTensorCOO`. 296 static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp, 297 Value coo) { 298 SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)}; 299 createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off); 300 } 301 302 /// Generates a call to release/delete a `SparseTensorIterator`. 303 static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp, 304 Value iter) { 305 SmallString<26> name{"delSparseTensorIterator", 306 primaryTypeFunctionSuffix(elemTp)}; 307 createFuncCall(builder, loc, name, {}, iter, EmitCInterface::Off); 308 } 309 310 /// Generates a call that adds one element to a coordinate scheme. 311 /// In particular, this generates code like the following: 312 /// val = a[i1,..,ik]; 313 /// if val != 0 314 /// t->add(&val, [i1,..,ik], [p1,..,pk]); 315 static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType, 316 Value lvlCOO, Value valPtr, Value dimCoords, 317 Value dimToLvl) { 318 SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)}; 319 SmallVector<Value, 4> params{lvlCOO, valPtr, dimCoords, dimToLvl}; 320 Type pTp = getOpaquePointerType(builder); 321 createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On); 322 } 323 324 /// Generates a call to `iter->getNext()`. If there is a next element, 325 /// then it is copied into the out-parameters `coords` and `elemPtr`, 326 /// and the return value is true. If there isn't a next element, then 327 /// the return value is false. 
328 /// 329 /// The `coords` argument uses the same coordinate-space as the `iter` 330 /// (which can be either dim- or lvl-coords, depending on context). 331 static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter, 332 Value coords, Value elemPtr) { 333 Type elemTp = cast<ShapedType>(elemPtr.getType()).getElementType(); 334 SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)}; 335 SmallVector<Value, 3> params{iter, coords, elemPtr}; 336 Type i1 = builder.getI1Type(); 337 return createFuncCall(builder, loc, name, i1, params, EmitCInterface::On) 338 .getResult(0); 339 } 340 341 /// Loads the value stored in `elemPtr`, and stores it at the coordinates 342 /// `cvs` into a dense tensor created by `allocDenseTensor`. 343 static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc, 344 Value elemPtr, Value tensor, 345 ValueRange cvs) { 346 Value elemV = builder.create<memref::LoadOp>(loc, elemPtr); 347 builder.create<memref::StoreOp>(loc, elemV, tensor, cvs); 348 } 349 350 /// Determine if the runtime library supports direct conversion to the 351 /// given target `dimTypes`. 352 static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) { 353 bool alreadyCompressed = false; 354 for (const auto dlt : dimTypes) { 355 if (isCompressedDLT(dlt)) { 356 if (alreadyCompressed) 357 return false; // Multiple compressed dimensions not yet supported. 358 alreadyCompressed = true; 359 } else if (isDenseDLT(dlt)) { 360 if (alreadyCompressed) 361 return false; // Dense after Compressed not yet supported. 362 } else if (isSingletonDLT(dlt)) { 363 // Direct conversion doesn't have any particular problems with 364 // singleton after compressed. 365 } else { // TODO: investigate 366 return false; 367 } 368 } 369 return true; 370 } 371 372 //===----------------------------------------------------------------------===// 373 // Conversion rules. 
374 //===----------------------------------------------------------------------===// 375 376 /// Sparse conversion rule for returns. 377 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { 378 public: 379 using OpConversionPattern::OpConversionPattern; 380 LogicalResult 381 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 382 ConversionPatternRewriter &rewriter) const override { 383 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); 384 return success(); 385 } 386 }; 387 388 /// Sparse conversion rule for accessing dimension-sizes. 389 class SparseTensorToDimSizeConverter 390 : public OpConversionPattern<tensor::DimOp> { 391 public: 392 using OpConversionPattern::OpConversionPattern; 393 LogicalResult 394 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, 395 ConversionPatternRewriter &rewriter) const override { 396 const auto stt = getSparseTensorType(op.getSource()); 397 // Only rewrite sparse DimOp. 398 if (!stt.hasEncoding()) 399 return failure(); 400 // Only rewrite DimOp with constant index. 401 std::optional<int64_t> dim = op.getConstantIndex(); 402 if (!dim) 403 return failure(); 404 // Generate the call. 405 Value src = adaptor.getOperands()[0]; 406 rewriter.replaceOp( 407 op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim)); 408 return success(); 409 } 410 }; 411 412 /// Sparse conversion rule for trivial tensor casts. 413 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { 414 public: 415 using OpConversionPattern::OpConversionPattern; 416 LogicalResult 417 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 418 ConversionPatternRewriter &rewriter) const override { 419 // Only rewrite identically annotated source/dest. 
420 auto encDst = getSparseTensorEncoding(op.getType()); 421 auto encSrc = getSparseTensorEncoding(op.getSource().getType()); 422 if (!encDst || encDst != encSrc) 423 return failure(); 424 rewriter.replaceOp(op, adaptor.getOperands()); 425 return success(); 426 } 427 }; 428 429 /// Sparse conversion rule for the new operator. 430 class SparseTensorNewConverter : public OpConversionPattern<NewOp> { 431 public: 432 using OpConversionPattern::OpConversionPattern; 433 LogicalResult 434 matchAndRewrite(NewOp op, OpAdaptor adaptor, 435 ConversionPatternRewriter &rewriter) const override { 436 Location loc = op.getLoc(); 437 const auto stt = getSparseTensorType(op); 438 if (!stt.hasEncoding()) 439 return failure(); 440 // Construct the reader opening method calls. 441 SmallVector<Value> dimShapesValues; 442 Value dimSizesBuffer; 443 Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0], 444 dimShapesValues, dimSizesBuffer); 445 // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers. 446 Value dim2lvlBuffer; 447 Value lvl2dimBuffer; 448 Value lvlSizesBuffer = 449 genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer, 450 dim2lvlBuffer, lvl2dimBuffer); 451 // Use the `reader` to parse the file. 452 Type opaqueTp = getOpaquePointerType(rewriter); 453 Type eltTp = stt.getElementType(); 454 Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp); 455 SmallVector<Value, 8> params{ 456 reader, 457 lvlSizesBuffer, 458 genLvlTypesBuffer(rewriter, loc, stt), 459 dim2lvlBuffer, 460 lvl2dimBuffer, 461 constantPosTypeEncoding(rewriter, loc, stt.getEncoding()), 462 constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()), 463 valTp}; 464 Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader", 465 opaqueTp, params, EmitCInterface::On) 466 .getResult(0); 467 // Free the memory for `reader`. 
468 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader}, 469 EmitCInterface::Off); 470 rewriter.replaceOp(op, tensor); 471 return success(); 472 } 473 }; 474 475 /// Sparse conversion rule for the alloc operator. 476 /// TODO(springerm): remove when bufferization.alloc_tensor is gone 477 class SparseTensorAllocConverter 478 : public OpConversionPattern<bufferization::AllocTensorOp> { 479 public: 480 using OpConversionPattern::OpConversionPattern; 481 LogicalResult 482 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, 483 ConversionPatternRewriter &rewriter) const override { 484 if (op.getCopy()) 485 return rewriter.notifyMatchFailure(op, 486 "sparse tensor copy not implemented"); 487 Location loc = op.getLoc(); 488 const auto stt = getSparseTensorType(op); 489 if (!stt.hasEncoding()) 490 return failure(); 491 // Gather all dimension sizes as SSA values. 492 const Dimension dimRank = stt.getDimRank(); 493 SmallVector<Value> dimSizes; 494 dimSizes.reserve(dimRank); 495 unsigned operandCtr = 0; 496 for (Dimension d = 0; d < dimRank; ++d) { 497 dimSizes.push_back( 498 stt.isDynamicDim(d) 499 ? adaptor.getOperands()[operandCtr++] 500 : constantIndex(rewriter, loc, op.getStaticSize(d))); 501 } 502 // Generate the call to construct empty tensor. The sizes are 503 // explicitly defined by the arguments to the alloc operator. 504 rewriter.replaceOp(op, NewCallParams(rewriter, loc) 505 .genBuffers(stt, dimSizes) 506 .genNewCall(Action::kEmpty)); 507 return success(); 508 } 509 }; 510 511 /// Sparse conversion rule for the empty tensor. 
512 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { 513 public: 514 using OpConversionPattern::OpConversionPattern; 515 LogicalResult 516 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, 517 ConversionPatternRewriter &rewriter) const override { 518 Location loc = op.getLoc(); 519 const auto stt = getSparseTensorType(op); 520 if (!stt.hasEncoding()) 521 return failure(); 522 // Gather all dimension sizes as SSA values. 523 const Dimension dimRank = stt.getDimRank(); 524 SmallVector<Value> dimSizes; 525 dimSizes.reserve(dimRank); 526 auto shape = op.getType().getShape(); 527 unsigned operandCtr = 0; 528 for (Dimension d = 0; d < dimRank; ++d) { 529 dimSizes.push_back(stt.isDynamicDim(d) 530 ? adaptor.getOperands()[operandCtr++] 531 : constantIndex(rewriter, loc, shape[d])); 532 } 533 // Generate the call to construct empty tensor. The sizes are 534 // explicitly defined by the arguments to the alloc operator. 535 rewriter.replaceOp(op, NewCallParams(rewriter, loc) 536 .genBuffers(stt, dimSizes) 537 .genNewCall(Action::kEmpty)); 538 return success(); 539 } 540 }; 541 542 /// Sparse conversion rule for the convert operator. 
/// Three cases are distinguished by the encodings of source and target:
/// sparse => sparse (directly or via a COO intermediate), sparse => dense
/// (via a runtime iterator), and dense or sparse-constant => sparse (via a
/// COO filled element by element). Dense => dense is not handled.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  /// Constructs the pattern with the given conversion options.
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  /// Constructs the pattern with a type converter and conversion options.
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getSource());
    const auto dstTp = getSparseTensorType(op);
    // Dense => dense conversions are left to other patterns.
    if (!srcTp.hasEncoding() && !dstTp.hasEncoding())
      return failure();

    const Dimension dimRank = srcTp.getDimRank();
    const Type elemTp = srcTp.getElementType();
    // `src` is the type-converted source operand.
    const Value src = adaptor.getOperands()[0];
    if (srcTp.hasEncoding() && dstTp.hasEncoding()) {
      const auto srcEnc = srcTp.getEncoding();
      const auto dstEnc = dstTp.getEncoding();
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (dstEnc == srcEnc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      NewCallParams params(rewriter, loc);
      SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
      // NOTE(review): `useDirectConversion` has no initializer and relies on
      // the switch covering every enumerator; some compilers may warn
      // (-Wmaybe-uninitialized). Also, the `kDirect` support check below is
      // only an `assert`, i.e. it is not checked in release builds.
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(dstEnc.getLvlTypes()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes());
        break;
      }
      if (useDirectConversion) {
        // A single runtime call converts `src` into the target encoding.
        rewriter.replaceOp(
            op, params.genBuffers(srcTp.withEncoding(dstEnc), dimSizes)
                    .genNewCall(Action::kSparseToSparse, src));
      } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        const auto mixedEnc =
            dstEnc.withBitWidths(srcEnc.getPosWidth(), srcEnc.getCrdWidth());
        // TODO: This is the only place where `kToCOO` (or `kToIterator`)
        // is called with a non-identity permutation. Is there any clean
        // way to push the permutation over to the `kFromCOO` side instead?
        Value coo = params.genBuffers(srcTp.withEncoding(mixedEnc), dimSizes)
                        .genNewCall(Action::kToCOO, src);
        // Reuse the buffers, only switching the template types to `dstEnc`.
        Value dst = params.setTemplateTypes(srcTp.withEncoding(dstEnc))
                        .genNewCall(Action::kFromCOO, coo);
        // The intermediate COO is released once `dst` has been built.
        genDelCOOCall(rewriter, loc, elemTp, coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (srcTp.hasEncoding() && !dstTp.hasEncoding()) {
      const auto srcEnc = srcTp.getEncoding();
      // This is sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = new SparseTensorIterator(src);
      //   while (elem = iter->getNext()) {
      //     dst[elem.coords] = elem.value;
      //   }
      //   delete iter;
      //
      // Fabricate a no-permutation encoding for NewCallParams
      // The position/coordinate types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      const auto dstEnc = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
          AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
      SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
      Value iter = NewCallParams(rewriter, loc)
                       .genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
                       .genNewCall(Action::kToIterator, src);
      // Scratch buffers receiving each element's coordinates and value.
      const Type iTp = rewriter.getIndexType();
      Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
      Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
      // Emit `while (getNext(iter, dimCoords, elemPtr)) dst[...] = *elemPtr`:
      // the `before` region calls `getNext` and yields its result as the
      // loop condition; the `after` region loads the coordinates and stores
      // the element into `dst`.
      const SmallVector<Value> noArgs;
      const SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, loc, iter, dimCoords, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      const auto dcvs = loadAll(rewriter, loc, dimRank, dimCoords);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, dcvs);
      rewriter.create<scf::YieldOp>(loc);
      // Continue after the loop: release the iterator and expose `dst` as
      // the result tensor of `op`.
      rewriter.setInsertionPointAfter(whileOp);
      genDelIteratorCall(rewriter, loc, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
          op, dstTp.getRankedTensorType(), dst);
      return success();
    }
    assert(!srcTp.hasEncoding() && dstTp.hasEncoding());
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = coordinates[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too much low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; whereas the entire body of the innermost
    // loop is generated by genAddElt().
    // (Concretely, the innermost body is the lambda passed to
    // `genDenseTensorOrSparseConstantIterLoop` below: it stores the
    // coordinates and the value into scratch buffers, then calls
    // `genAddEltCall`.)
    SmallVector<Value> dimSizes;
    sizesFromSrc(rewriter, dimSizes, loc, src);
    NewCallParams params(rewriter, loc);
    Value coo =
        params.genBuffers(dstTp, dimSizes).genNewCall(Action::kEmptyCOO);
    const Type iTp = rewriter.getIndexType();
    Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
    Value dimToLvl = params.getDimToLvl();
    Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
    genDenseTensorOrSparseConstantIterLoop(
        rewriter, loc, src, dimRank,
        [&](OpBuilder &builder, Location loc, Value val, ValueRange dcvs) {
          assert(dcvs.size() == static_cast<size_t>(dimRank));
          storeAll(builder, loc, dimCoords, dcvs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(builder, loc, elemTp, coo, elemPtr, dimCoords,
                        dimToLvl);
        });
    // Final call to construct sparse tensor storage.
    Value dst = params.genNewCall(Action::kFromCOO, coo);
    // The intermediate COO is released once `dst` has been built.
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the dealloc operator.
719 class SparseTensorDeallocConverter 720 : public OpConversionPattern<bufferization::DeallocTensorOp> { 721 public: 722 using OpConversionPattern::OpConversionPattern; 723 LogicalResult 724 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, 725 ConversionPatternRewriter &rewriter) const override { 726 if (!getSparseTensorType(op.getTensor()).hasEncoding()) 727 return failure(); 728 StringRef name = "delSparseTensor"; 729 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(), 730 EmitCInterface::Off); 731 rewriter.eraseOp(op); 732 return success(); 733 } 734 }; 735 736 /// Sparse conversion rule for position accesses. 737 class SparseTensorToPositionsConverter 738 : public OpConversionPattern<ToPositionsOp> { 739 public: 740 using OpConversionPattern::OpConversionPattern; 741 LogicalResult 742 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, 743 ConversionPatternRewriter &rewriter) const override { 744 Type resTp = op.getType(); 745 Type posTp = cast<ShapedType>(resTp).getElementType(); 746 SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)}; 747 Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel()); 748 replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl}, 749 EmitCInterface::On); 750 return success(); 751 } 752 }; 753 754 /// Sparse conversion rule for coordinate accesses. 755 class SparseTensorToCoordinatesConverter 756 : public OpConversionPattern<ToCoordinatesOp> { 757 public: 758 using OpConversionPattern::OpConversionPattern; 759 LogicalResult 760 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor, 761 ConversionPatternRewriter &rewriter) const override { 762 // TODO: use `SparseTensorType::getCrdType` instead. 
763 Type resType = op.getType(); 764 const Type crdTp = cast<ShapedType>(resType).getElementType(); 765 SmallString<19> name{"sparseCoordinates", 766 overheadTypeFunctionSuffix(crdTp)}; 767 Location loc = op->getLoc(); 768 Value lvl = constantIndex(rewriter, loc, op.getLevel()); 769 770 // The function returns a MemRef without a layout. 771 MemRefType callRetType = get1DMemRefType(crdTp, false); 772 SmallVector<Value> operands{adaptor.getTensor(), lvl}; 773 auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType, 774 operands, EmitCInterface::On); 775 Value callRet = 776 rewriter.create<func::CallOp>(loc, callRetType, fn, operands) 777 .getResult(0); 778 779 // Cast the MemRef type to the type expected by the users, though these 780 // two types should be compatible at runtime. 781 if (resType != callRetType) 782 callRet = rewriter.create<memref::CastOp>(loc, resType, callRet); 783 rewriter.replaceOp(op, callRet); 784 785 return success(); 786 } 787 }; 788 789 /// Sparse conversion rule for value accesses. 790 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> { 791 public: 792 using OpConversionPattern::OpConversionPattern; 793 LogicalResult 794 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, 795 ConversionPatternRewriter &rewriter) const override { 796 auto resType = cast<ShapedType>(op.getType()); 797 rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType, 798 adaptor.getOperands())); 799 return success(); 800 } 801 }; 802 803 /// Sparse conversion rule for number of entries operator. 804 class SparseNumberOfEntriesConverter 805 : public OpConversionPattern<NumberOfEntriesOp> { 806 public: 807 using OpConversionPattern::OpConversionPattern; 808 LogicalResult 809 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, 810 ConversionPatternRewriter &rewriter) const override { 811 Location loc = op.getLoc(); 812 // Query values array size for the actually stored values size. 
813 Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType(); 814 auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType); 815 Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands()); 816 rewriter.replaceOpWithNewOp<memref::DimOp>(op, values, 817 constantIndex(rewriter, loc, 0)); 818 return success(); 819 } 820 }; 821 822 /// Sparse conversion rule for tensor rematerialization. 823 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { 824 public: 825 using OpConversionPattern::OpConversionPattern; 826 LogicalResult 827 matchAndRewrite(LoadOp op, OpAdaptor adaptor, 828 ConversionPatternRewriter &rewriter) const override { 829 if (op.getHasInserts()) { 830 // Finalize any pending insertions. 831 StringRef name = "endInsert"; 832 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(), 833 EmitCInterface::Off); 834 } 835 rewriter.replaceOp(op, adaptor.getOperands()); 836 return success(); 837 } 838 }; 839 840 /// Sparse conversion rule for the insertion operator. 841 class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> { 842 public: 843 using OpConversionPattern::OpConversionPattern; 844 LogicalResult 845 matchAndRewrite(InsertOp op, OpAdaptor adaptor, 846 ConversionPatternRewriter &rewriter) const override { 847 // Note that the current regime only allows for strict lexicographic 848 // coordinate order. All values are passed by reference through stack 849 // allocated memrefs. 
850 Location loc = op->getLoc(); 851 const auto stt = getSparseTensorType(op.getTensor()); 852 const auto elemTp = stt.getElementType(); 853 const Level lvlRank = stt.getLvlRank(); 854 auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType()); 855 auto vref = genAllocaScalar(rewriter, loc, elemTp); 856 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords()); 857 rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref); 858 SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)}; 859 createFuncCall(rewriter, loc, name, {}, 860 {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On); 861 rewriter.replaceOp(op, adaptor.getTensor()); 862 return success(); 863 } 864 }; 865 866 /// Sparse conversion rule for the expand operator. 867 class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> { 868 public: 869 using OpConversionPattern::OpConversionPattern; 870 LogicalResult 871 matchAndRewrite(ExpandOp op, OpAdaptor adaptor, 872 ConversionPatternRewriter &rewriter) const override { 873 Location loc = op->getLoc(); 874 const auto srcTp = getSparseTensorType(op.getTensor()); 875 Type eltType = srcTp.getElementType(); 876 Type boolType = rewriter.getIntegerType(1); 877 Type idxType = rewriter.getIndexType(); 878 // All initialization should be done on entry of the loop nest. 879 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); 880 // Get the cardinality of valid coordinates for the innermost level. 881 Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(), 882 srcTp.getLvlRank() - 1); 883 // Allocate temporary buffers for values, filled-switch, and coordinates. 884 // We do not use stack buffers for this, since the expanded size may 885 // be rather large (as it envelops a single expanded dense dimension). 
886 Value values = genAlloc(rewriter, loc, sz, eltType); 887 Value filled = genAlloc(rewriter, loc, sz, boolType); 888 Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType); 889 Value zero = constantZero(rewriter, loc, idxType); 890 // Reset the values/filled-switch to all-zero/false. Note that this 891 // introduces an O(N) operation into the computation, but this reset 892 // operation is amortized over the innermost loops for the access 893 // pattern expansion. As noted in the operation doc, we would like 894 // to amortize this setup cost even between kernels. 895 rewriter.create<linalg::FillOp>( 896 loc, ValueRange{constantZero(rewriter, loc, eltType)}, 897 ValueRange{values}); 898 rewriter.create<linalg::FillOp>( 899 loc, ValueRange{constantZero(rewriter, loc, boolType)}, 900 ValueRange{filled}); 901 // Replace expansion op with these buffers and initial coordinate. 902 assert(op.getNumResults() == 4); 903 rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero}); 904 return success(); 905 } 906 }; 907 908 /// Sparse conversion rule for the compress operator. 909 class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> { 910 public: 911 using OpConversionPattern::OpConversionPattern; 912 LogicalResult 913 matchAndRewrite(CompressOp op, OpAdaptor adaptor, 914 ConversionPatternRewriter &rewriter) const override { 915 Location loc = op->getLoc(); 916 // Note that this method call resets the values/filled-switch back to 917 // all-zero/false by only iterating over the set elements, so the 918 // complexity remains proportional to the sparsity of the expanded 919 // access pattern. 
920 Value values = adaptor.getValues(); 921 Value filled = adaptor.getFilled(); 922 Value added = adaptor.getAdded(); 923 Value count = adaptor.getCount(); 924 Value tensor = adaptor.getTensor(); 925 const auto stt = getSparseTensorType(op.getTensor()); 926 const Type elemTp = stt.getElementType(); 927 const Level lvlRank = stt.getLvlRank(); 928 auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType()); 929 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords()); 930 SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)}; 931 createFuncCall(rewriter, loc, name, {}, 932 {tensor, lvlCoords, values, filled, added, count}, 933 EmitCInterface::On); 934 rewriter.replaceOp(op, adaptor.getTensor()); 935 // Deallocate the buffers on exit of the loop nest. 936 Operation *parent = getTop(op); 937 rewriter.setInsertionPointAfter(parent); 938 rewriter.create<memref::DeallocOp>(loc, values); 939 rewriter.create<memref::DeallocOp>(loc, filled); 940 rewriter.create<memref::DeallocOp>(loc, added); 941 return success(); 942 } 943 }; 944 945 /// Sparse conversion rule for the output operator. 946 class SparseTensorOutConverter : public OpConversionPattern<OutOp> { 947 public: 948 using OpConversionPattern::OpConversionPattern; 949 LogicalResult 950 matchAndRewrite(OutOp op, OpAdaptor adaptor, 951 ConversionPatternRewriter &rewriter) const override { 952 const Location loc = op->getLoc(); 953 const auto srcTp = getSparseTensorType(op.getTensor()); 954 // Convert to default permuted COO. 955 Value src = adaptor.getOperands()[0]; 956 SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src); 957 Value coo = NewCallParams(rewriter, loc) 958 .genBuffers(srcTp.withoutDimToLvl(), dimSizes) 959 .genNewCall(Action::kToCOO, src); 960 // Then output the tensor to external file with coordinates in the 961 // externally visible lexicographic coordinate order. 
A sort is 962 // required if the source was not in that order yet (note that the 963 // sort can be dropped altogether if external format does not care 964 // about the order at all, but here we assume it does). 965 const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity()); 966 SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort}; 967 const Type elemTp = srcTp.getElementType(); 968 SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)}; 969 createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off); 970 genDelCOOCall(rewriter, loc, elemTp, coo); 971 rewriter.eraseOp(op); 972 return success(); 973 } 974 }; 975 976 /// Sparse conversion rule for the sparse_tensor.pack operator. 977 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> { 978 public: 979 using OpConversionPattern::OpConversionPattern; 980 LogicalResult 981 matchAndRewrite(AssembleOp op, OpAdaptor adaptor, 982 ConversionPatternRewriter &rewriter) const override { 983 const Location loc = op->getLoc(); 984 const auto dstTp = getSparseTensorType(op.getResult()); 985 // AssembleOps always returns a static shaped tensor result. 986 assert(dstTp.hasStaticDimShape()); 987 SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp); 988 Value dst = 989 NewCallParams(rewriter, loc) 990 .genBuffers(dstTp.withoutDimToLvl(), dimSizes) 991 .genNewCall(Action::kPack, 992 genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(), 993 adaptor.getValues())); 994 rewriter.replaceOp(op, dst); 995 return success(); 996 } 997 }; 998 999 } // namespace 1000 1001 //===----------------------------------------------------------------------===// 1002 // Sparse tensor type conversion into opaque pointer. 
1003 //===----------------------------------------------------------------------===// 1004 1005 mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() { 1006 addConversion([](Type type) { return type; }); 1007 addConversion(convertSparseTensorTypes); 1008 } 1009 1010 //===----------------------------------------------------------------------===// 1011 // Public method for populating conversion rules. 1012 //===----------------------------------------------------------------------===// 1013 1014 /// Populates the given patterns list with conversion rules required for 1015 /// the sparsification of linear algebra operations. 1016 void mlir::populateSparseTensorConversionPatterns( 1017 TypeConverter &typeConverter, RewritePatternSet &patterns, 1018 const SparseTensorConversionOptions &options) { 1019 patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter, 1020 SparseCastConverter, SparseTensorNewConverter, 1021 SparseTensorAllocConverter, SparseTensorEmptyConverter, 1022 SparseTensorDeallocConverter, SparseTensorToPositionsConverter, 1023 SparseTensorToCoordinatesConverter, 1024 SparseTensorToValuesConverter, SparseNumberOfEntriesConverter, 1025 SparseTensorLoadConverter, SparseTensorInsertConverter, 1026 SparseTensorExpandConverter, SparseTensorCompressConverter, 1027 SparseTensorOutConverter, SparseTensorAssembleConverter>( 1028 typeConverter, patterns.getContext()); 1029 patterns.add<SparseTensorConvertConverter>(typeConverter, 1030 patterns.getContext(), options); 1031 } 1032