//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls to a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//
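
// Example (illustrative only; the exact IR depends on the encoding and the
// pointer model in use): a sparse tensor type such as
//   tensor<8x8xf64, #CSR>
// is converted to the opaque pointer type
//   !llvm.ptr<i8>
// and operations on it are lowered to calls into the runtime support
// library, such as `sparseLvlSize` or `newSparseTensor`.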

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
  return std::nullopt;
}

/// Replaces `op` with a `func::CallOp` to the function reference
/// returned by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
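///
/// For example (illustrative only): for a CSC-like encoding whose dimToLvl
/// map is (d0, d1) -> (d1, d0), querying level 0 resolves to dimension 1,
/// so a static size on dimension 1 folds to a constant here.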
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer to the buffer that backs the tensor.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a buffer containing the bare pointers of the given level
/// tensors and value tensor, and returns it as a single opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
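///
/// For example (an illustrative sketch of the intended chaining style,
/// mirroring the uses later in this file):
///
///   Value tensor = NewCallParams(rewriter, loc)
///                      .genBuffers(stt, dimSizes)
///                      .genNewCall(Action::kEmpty);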
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returns `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    setTemplateTypes(stt);
    // Finally, make note that initialization is complete.
    assert(isInitialized() && "Initialization failed");
    // And return `this` for method chaining.
    return *this;
  }

  /// (Re)sets the C++ template type parameters, and returns `this`
  /// for method chaining. This is already done as part of `genBuffers`,
  /// but is factored out so that it can also be called independently
  /// whenever subsequent `genNewCall` calls want to reuse the same
  /// buffers but with different type parameters.
  //
  // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
  // is there a better way to handle that than this one-off setter method?
  NewCallParams &setTemplateTypes(SparseTensorType stt) {
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Gets the dimension-to-level mapping.
  //
  // TODO: This is only ever used for passing into `genAddEltCall`;
  // is there a better way to encapsulate that pattern (both to avoid
  // this one-off getter, and to avoid potential mixups)?
  Value getDimToLvl() const {
    assert(isInitialized() && "Must initialize before getDimToLvl");
    return params[kParamDim2Lvl];
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}

/// Generates a call to release/delete a `SparseTensorIterator`.
static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
                               Value iter) {
  SmallString<26> name{"delSparseTensorIterator",
                       primaryTypeFunctionSuffix(elemTp)};
  createFuncCall(builder, loc, name, {}, iter, EmitCInterface::Off);
}

/// Generates a call that adds one element to a coordinate scheme.
/// In particular, this generates code like the following:
///   val = a[i1,..,ik];
///   if val != 0
///     t->add(&val, [i1,..,ik], [p1,..,pk]);
static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
                          Value lvlCOO, Value valPtr, Value dimCoords,
                          Value dimToLvl) {
  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
  SmallVector<Value, 4> params{lvlCOO, valPtr, dimCoords, dimToLvl};
  Type pTp = getOpaquePointerType(builder);
  createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On);
}

/// Generates a call to `iter->getNext()`.  If there is a next element,
/// then it is copied into the out-parameters `coords` and `elemPtr`,
/// and the return value is true.  If there isn't a next element, then
/// the return value is false.
///
/// The `coords` argument uses the same coordinate-space as the `iter`
/// (which can be either dim- or lvl-coords, depending on context).
static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
                            Value coords, Value elemPtr) {
  Type elemTp = cast<ShapedType>(elemPtr.getType()).getElementType();
  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
  SmallVector<Value, 3> params{iter, coords, elemPtr};
  Type i1 = builder.getI1Type();
  return createFuncCall(builder, loc, name, i1, params, EmitCInterface::On)
      .getResult(0);
}

/// Loads the value stored in `elemPtr`, and stores it at the coordinates
/// `cvs` into a dense tensor created by `allocDenseTensor`.
static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
                                        Value elemPtr, Value tensor,
                                        ValueRange cvs) {
  Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
  builder.create<memref::StoreOp>(loc, elemV, tensor, cvs);
}

/// Determines whether the runtime library supports direct conversion
/// to the given target `dimTypes`.
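///
/// For example (illustrative only): a CSR-like target (dense, compressed)
/// is supported, whereas (compressed, compressed) and (compressed, dense)
/// targets are rejected by the rules below.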
static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
  bool alreadyCompressed = false;
  for (const auto dlt : dimTypes) {
    if (isCompressedDLT(dlt)) {
      if (alreadyCompressed)
        return false; // Multiple compressed dimensions not yet supported.
      alreadyCompressed = true;
    } else if (isDenseDLT(dlt)) {
      if (alreadyCompressed)
        return false; // Dense after compressed not yet supported.
    } else if (isSingletonDLT(dlt)) {
      // Direct conversion doesn't have any particular problems with
      // singleton after compressed.
    } else { // TODO: investigate
      return false;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing dimension-sizes.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse DimOp.
    if (!stt.hasEncoding())
      return failure();
    // Only rewrite DimOp with constant index.
    std::optional<int64_t> dim = op.getConstantIndex();
    if (!dim)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(
        op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the reader opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    Value lvlSizesBuffer =
        genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
                      dim2lvlBuffer, lvl2dimBuffer);
    // Use the `reader` to parse the file.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Type eltTp = stt.getElementType();
    Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
    SmallVector<Value, 8> params{
        reader,
        lvlSizesBuffer,
        genLvlTypesBuffer(rewriter, loc, stt),
        dim2lvlBuffer,
        lvl2dimBuffer,
        constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
        constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
        valTp};
    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
                                  opaqueTp, params, EmitCInterface::On)
                       .getResult(0);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorConvertConverter(MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(context), options(o) {}
  SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
                               SparseTensorConversionOptions o)
      : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}

  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getSource());
    const auto dstTp = getSparseTensorType(op);
    if (!srcTp.hasEncoding() && !dstTp.hasEncoding())
      return failure();

    const Dimension dimRank = srcTp.getDimRank();
    const Type elemTp = srcTp.getElementType();
    const Value src = adaptor.getOperands()[0];
    if (srcTp.hasEncoding() && dstTp.hasEncoding()) {
      const auto srcEnc = srcTp.getEncoding();
      const auto dstEnc = dstTp.getEncoding();
      // This is a sparse => sparse conversion, which is handled as follows:
      //   t = src->toCOO();         ; src to COO in dst order
      //   dst = newSparseTensor(t)
      // Using the coordinate scheme as an intermediate does not always
      // yield the fastest conversion but avoids the need for a full
      // O(N^2) conversion matrix.
      if (dstEnc == srcEnc) {
        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
        return success();
      }
      NewCallParams params(rewriter, loc);
      SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
      bool useDirectConversion;
      switch (options.sparseToSparseStrategy) {
      case SparseToSparseConversionStrategy::kViaCOO:
        useDirectConversion = false;
        break;
      case SparseToSparseConversionStrategy::kDirect:
        useDirectConversion = true;
        assert(canUseDirectConversion(dstEnc.getLvlTypes()) &&
               "Unsupported target for direct sparse-to-sparse conversion");
        break;
      case SparseToSparseConversionStrategy::kAuto:
        useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes());
        break;
      }
      if (useDirectConversion) {
        rewriter.replaceOp(
            op, params.genBuffers(srcTp.withEncoding(dstEnc), dimSizes)
                    .genNewCall(Action::kSparseToSparse, src));
      } else { // use via-COO conversion.
        // Set up encoding with right mix of src and dst so that the two
        // method calls can share most parameters, while still providing
        // the correct sparsity information to either of them.
        const auto mixedEnc =
            dstEnc.withBitWidths(srcEnc.getPosWidth(), srcEnc.getCrdWidth());
        // TODO: This is the only place where `kToCOO` (or `kToIterator`)
        // is called with a non-identity permutation.  Is there any clean
        // way to push the permutation over to the `kFromCOO` side instead?
        Value coo = params.genBuffers(srcTp.withEncoding(mixedEnc), dimSizes)
                        .genNewCall(Action::kToCOO, src);
        Value dst = params.setTemplateTypes(srcTp.withEncoding(dstEnc))
                        .genNewCall(Action::kFromCOO, coo);
        genDelCOOCall(rewriter, loc, elemTp, coo);
        rewriter.replaceOp(op, dst);
      }
      return success();
    }
    if (srcTp.hasEncoding() && !dstTp.hasEncoding()) {
      const auto srcEnc = srcTp.getEncoding();
      // This is a sparse => dense conversion, which is handled as follows:
      //   dst = new Tensor(0);
      //   iter = new SparseTensorIterator(src);
      //   while (elem = iter->getNext()) {
      //     dst[elem.coords] = elem.value;
      //   }
      //   delete iter;
      //
      // Fabricate a no-permutation encoding for NewCallParams.
      // The position/coordinate types must be those of `src`.
      // The dimLevelTypes aren't actually used by Action::kToIterator.
      const auto dstEnc = SparseTensorEncodingAttr::get(
          op->getContext(),
          SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
          AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
      SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
      Value iter = NewCallParams(rewriter, loc)
                       .genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
                       .genNewCall(Action::kToIterator, src);
      const Type iTp = rewriter.getIndexType();
      Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
      Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
      // TODO: Dense buffers should be allocated/deallocated via the callback
      // in BufferizationOptions.
      Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
      const SmallVector<Value> noArgs;
      const SmallVector<Type> noTypes;
      auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
      rewriter.setInsertionPointToEnd(before);
      Value cond = genGetNextCall(rewriter, loc, iter, dimCoords, elemPtr);
      rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
      Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
      rewriter.setInsertionPointToStart(after);
      const auto dcvs = loadAll(rewriter, loc, dimRank, dimCoords);
      insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, dcvs);
      rewriter.create<scf::YieldOp>(loc);
      rewriter.setInsertionPointAfter(whileOp);
      genDelIteratorCall(rewriter, loc, elemTp, iter);
      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
          op, dstTp.getRankedTensorType(), dst);
      return success();
    }
    assert(!srcTp.hasEncoding() && dstTp.hasEncoding());
    // This is a dense => sparse conversion or a sparse constant in COO =>
    // sparse conversion, which is handled as follows:
    //   t = newSparseCOO()
    //   ...code to fill the COO tensor t...
    //   s = newSparseTensor(t)
    //
    // To fill the COO tensor from a dense tensor:
    //   for i1 in dim1
    //    ..
    //     for ik in dimk
    //       val = a[i1,..,ik]
    //       if val != 0
    //         t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // To fill the COO tensor from a sparse constant in COO format:
    //   for i in range(NNZ)
    //     val = values[i]
    //     [i1,..,ik] = coordinates[i]
    //     t->add(val, [i1,..,ik], [p1,..,pk])
    //
    // Note that the dense tensor traversal code is actually implemented
    // using MLIR IR to avoid having to expose too many low-level
    // memref traversal details to the runtime support library.
    // Also note that the code below only generates the "new" ops and
    // the loop-nest per se; the entire body of the innermost loop
    // is generated by genAddEltCall().
    SmallVector<Value> dimSizes;
    sizesFromSrc(rewriter, dimSizes, loc, src);
    NewCallParams params(rewriter, loc);
    Value coo =
        params.genBuffers(dstTp, dimSizes).genNewCall(Action::kEmptyCOO);
    const Type iTp = rewriter.getIndexType();
    Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
    Value dimToLvl = params.getDimToLvl();
    Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
    genDenseTensorOrSparseConstantIterLoop(
        rewriter, loc, src, dimRank,
        [&](OpBuilder &builder, Location loc, Value val, ValueRange dcvs) {
          assert(dcvs.size() == static_cast<size_t>(dimRank));
          storeAll(builder, loc, dimCoords, dcvs);
          builder.create<memref::StoreOp>(loc, val, elemPtr);
          genAddEltCall(builder, loc, elemTp, coo, elemPtr, dimCoords,
                        dimToLvl);
        });
    // Final call to construct the sparse tensor storage.
    Value dst = params.genNewCall(Action::kFromCOO, coo);
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.replaceOp(op, dst);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparseTensorConversionOptions options;
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array, which equals the number of
    // actually stored entries.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    auto vref = genAllocaScalar(rewriter, loc, elemTp);
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    Value coo = NewCallParams(rewriter, loc)
                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
                    .genNewCall(Action::kToCOO, src);
    // Then output the tensor to an external file with coordinates in the
    // externally visible lexicographic coordinate order.  A sort is
    // required if the source was not in that order yet (note that the
    // sort can be dropped altogether if the external format does not care
    // about the order at all, but here we assume it does).
    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
    const Type elemTp = srcTp.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // An AssembleOp always returns a statically shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
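///
/// A minimal usage sketch (illustrative only; the enclosing pass, the
/// MLIRContext `ctx`, and the conversion-target setup are assumed and
/// not shown here):
///
///   RewritePatternSet patterns(ctx);
///   SparseTensorTypeToPtrConverter converter;
///   SparseTensorConversionOptions options;
///   populateSparseTensorConversionPatterns(converter, patterns, options);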
void mlir::populateSparseTensorConversionPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    const SparseTensorConversionOptions &options) {
  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
               SparseCastConverter, SparseTensorNewConverter,
               SparseTensorAllocConverter, SparseTensorEmptyConverter,
               SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
               SparseTensorToCoordinatesConverter,
               SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
               SparseTensorLoadConverter, SparseTensorInsertConverter,
               SparseTensorExpandConverter, SparseTensorCompressConverter,
               SparseTensorOutConverter, SparseTensorAssembleConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorConvertConverter>(typeConverter,
                                             patterns.getContext(), options);
}