//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
  return std::nullopt;
}
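
// For example (an illustrative sketch, not emitted IR): a sparse type such
// as `tensor<?x?xf64, #CSR>` becomes the opaque handle `!llvm.ptr<i8>`,
// whereas types without a sparse encoding are left for other conversions.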

/// Replaces the `op` with a `func::CallOp` to the function reference
/// returned by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
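
// As an illustrative sketch, the two helpers above emit IR of the form
// (assuming `%tensor` was already converted to its opaque pointer form):
//
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseDimSize(%tensor, %c1) : (!llvm.ptr<i8>, index) -> index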

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` already checks that the expression is an
  // `AffineDimExpr`, which is all we care about (for supporting
  // permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}
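
// For example (illustrative): given `stt` of type `tensor<8x?xf64, #CSR>`,
// a query for dimension 0 folds to `arith.constant 8 : index`, whereas a
// query for dimension 1 emits the dynamic `sparseDimSize` call shown above.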

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}
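
// For example (illustrative), for a dynamic size `%sz` and element type f64,
// `genAlloc` produces:
//
//   %buf = memref.alloc(%sz) : memref<?xf64>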

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare pointers to the given
/// level buffers and values buffer, returned as a single opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Pass in the level buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Pass in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}
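
// The stack buffer built by `genLvlPtrsBuffers` thus has the layout
// (illustrative):
//
//   [ barePtr(lvlTensor_0), ..., barePtr(lvlTensor_{n-1}), barePtr(valTensor) ]
//
// and its own aligned address is what is reinterpreted as the single opaque
// pointer handed to the runtime.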

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
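
// Per the `kParam*` indices above, the runtime entry point is thus invoked
// as (a sketch; all buffers are passed as memrefs via the C interface):
//
//   newSparseTensor(dimSizes, lvlSizes, lvlTypes, dim2lvl, lvl2dim,
//                   posTp, crdTp, valTp, action, ptr)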

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing dimension-sizes.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse DimOp.
    if (!stt.hasEncoding())
      return failure();
    // Only rewrite DimOp with constant index.
    std::optional<int64_t> dim = op.getConstantIndex();
    if (!dim)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(
        op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
    return success();
  }
};
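
// An illustrative sketch of this rewrite, for a dynamic dimension:
//
//   %d = tensor.dim %t, %c1 : tensor<?x?xf64, #CSR>
//
// becomes (with %t already converted to !llvm.ptr<i8>):
//
//   %d = call @sparseDimSize(%t, %c1) : (!llvm.ptr<i8>, index) -> index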

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
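
// An illustrative sketch of the `sparse_tensor.new` lowering (the exact
// reader-opening calls are emitted by `genReader`):
//
//   %reader = ... reader-opening calls from genReader() ...
//   %tensor = call @newSparseTensor(...)  // Action::kFromReader with %reader
//   call @delSparseTensorReader(%reader) : (!llvm.ptr<i8>) -> ()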

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};
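
// An illustrative sketch of this rewrite (the suffix is derived from the
// position overhead type, e.g. `sparsePositions0` for `index` positions):
//
//   %p = sparse_tensor.positions %t { level = 1 : index }
//
// becomes
//
//   %p = call @sparsePositions0(%t, %c1)
//        : (!llvm.ptr<i8>, index) -> memref<?xindex>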

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef to the type expected by the users; the two types
    // are compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array, which equals the number of
    // actually stored entries.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    auto vref = genAllocaScalar(rewriter, loc, elemTp);
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
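
// An illustrative sketch of the emitted IR, for an f64 tensor with two
// levels:
//
//   %c2 = arith.constant 2 : index
//   %coords = memref.alloca(%c2) : memref<?xindex>  // filled with lvl coords
//   %vref = memref.alloca() : memref<f64>           // holds inserted value
//   call @lexInsertF64(%t, %coords, %vref)
//        : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> ()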

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    Value coo = NewCallParams(rewriter, loc)
                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
                    .genNewCall(Action::kToCOO, src);
    // Then output the tensor to an external file with coordinates in the
    // externally visible lexicographic coordinate order.  A sort is
    // required if the source was not in that order yet (note that the
    // sort could be dropped altogether if the external format does not
    // care about the order at all, but here we assume it does).
    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
    const Type elemTp = srcTp.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required to
/// convert sparse tensor primitives into calls to a runtime support library.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
           SparseCastConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorOutConverter, SparseTensorAssembleConverter>(
          typeConverter, patterns.getContext());
}
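
// A minimal usage sketch (illustrative, not part of this file): a conversion
// driver would typically combine the type converter and these patterns along
// the lines of:
//
//   SparseTensorTypeToPtrConverter converter;
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
//                          memref::MemRefDialect, LLVM::LLVMDialect>();
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();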