//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
  return std::nullopt;
}
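
// For example (illustrative), a sparse tensor type such as
//   tensor<?x?xf64, #CSR>
// maps to the opaque pointer type `!llvm.ptr<i8>`, while types without a
// sparse encoding are left for other conversions to handle.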

/// Replaces `op` with a `func::CallOp` to the function reference
/// obtained from `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
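
// A sketch of the IR this emits (SSA names illustrative):
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseLvlSize(%t, %c1) : (!llvm.ptr<i8>, index) -> index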

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` already checks that the expression is an
  // `AffineDimExpr`, which is all we care about (for supporting
  // permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to
/// have a static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}
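
// For example (illustrative), `genAlloc(rewriter, loc, %sz, f64)` emits:
//   %buf = memref.alloc(%sz) : memref<?xf64>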

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer to the buffer underlying the tensor.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare pointers to the level
/// buffers and the value buffer of the given tensors.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Pass in the level buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Pass in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
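  // The buffer of bare pointers is itself passed to the runtime as an
  // opaque pointer: its aligned address is extracted as an index, cast
  // to i64, and then reinterpreted as a pointer.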
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    setTemplateTypes(stt);
    // Finally, make note that initialization is complete.
    assert(isInitialized() && "Initialization failed");
    // And return `this` for method chaining.
    return *this;
  }

  /// (Re)sets the C++ template type parameters, and returns `this`
  /// for method chaining. This is already done as part of `genBuffers`,
  /// but is factored out so that it can also be called independently
  /// whenever subsequent `genNewCall` calls want to reuse the same
  /// buffers but different type parameters.
  //
  // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
  // is there a better way to handle that than this one-off setter method?
  NewCallParams &setTemplateTypes(SparseTensorType stt) {
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Gets the dimension-to-level mapping.
  //
  // TODO: This is only ever used for passing into `genAddEltCall`;
  // is there a better way to encapsulate that pattern (both to avoid
  // this one-off getter, and to avoid potential mixups)?
  Value getDimToLvl() const {
    assert(isInitialized() && "Must initialize before getDimToLvl");
    return params[kParamDim2Lvl];
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
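
// Typical usage, as it appears in the conversion patterns below:
//   Value tensor = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizes)
//                      .genNewCall(Action::kEmpty);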

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}
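
// For f64 values, for instance, this resolves to a call of the runtime's
// `sparseValuesF64` (the suffix shown is illustrative of what
// `primaryTypeFunctionSuffix` produces).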

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing dimension-sizes.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse DimOp.
    if (!stt.hasEncoding())
      return failure();
    // Only rewrite DimOp with constant index.
    std::optional<int64_t> dim = op.getConstantIndex();
    if (!dim)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(
        op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
    return success();
  }
};
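
// For example (illustrative), a dynamic dimension query such as
//   %d = tensor.dim %t, %c1 : tensor<?x?xf64, #CSR>
// lowers to a runtime query on the opaque pointer:
//   %d = call @sparseDimSize(%t, %c1) : (!llvm.ptr<i8>, index) -> index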

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the reader opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    Value lvlSizesBuffer =
        genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
                      dim2lvlBuffer, lvl2dimBuffer);
    // Use the `reader` to parse the file.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Type eltTp = stt.getElementType();
    Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
    SmallVector<Value, 8> params{
        reader,
        lvlSizesBuffer,
        genLvlTypesBuffer(rewriter, loc, stt),
        dim2lvlBuffer,
        lvl2dimBuffer,
        constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
        constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
        valTp};
    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
                                  opaqueTp, params, EmitCInterface::On)
                       .getResult(0);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
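
// The emitted runtime call sequence is roughly (illustrative, names shown
// without any C-interface wrapper prefix):
//   %reader = ...                                  // created by genReader
//   %tensor = call @newSparseTensorFromReader(%reader, ...)
//   call @delSparseTensorReader(%reader)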

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};
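
// For example (illustrative), with 64-bit positions the call resolves to
// `sparsePositions64`, returning the positions of the requested level as a
// memref view into the tensor's storage.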

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef to the type expected by its users, though the two
    // types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array to obtain the number of
    // actually stored entries.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};
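
// That is, the number of entries is computed as (illustrative, for f64):
//   %vals = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>
//   %n = memref.dim %vals, %c0 : memref<?xf64>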

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through
    // stack-allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    auto vref = genAllocaScalar(rewriter, loc, elemTp);
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
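
// Sketch of the emitted call (illustrative, assuming f64 values):
//   call @lexInsertF64(%tensor, %lvlCoords, %vref)
// where %lvlCoords is a stack-allocated memref of level-coordinates and
// %vref holds the value being inserted.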

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};
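
// Sketch of the emitted call (illustrative, assuming f64 values):
//   call @expInsertF64(%tensor, %lvlCoords, %values, %filled, %added, %count)
// followed, after the enclosing loop nest, by deallocation of the three
// buffers produced by the expand operator.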

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    Value coo = NewCallParams(rewriter, loc)
                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
                    .genNewCall(Action::kToCOO, src);
    // Then output the tensor to an external file with coordinates in the
    // externally visible lexicographic coordinate order.  A sort is
    // required if the source was not in that order yet (note that the
    // sort can be dropped altogether if the external format does not care
    // about the order at all, but here we assume it does).
    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
    const Type elemTp = srcTp.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically-shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
           SparseCastConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorOutConverter, SparseTensorAssembleConverter>(
          typeConverter, patterns.getContext());
}
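
// A minimal sketch of how a pass might drive these patterns (the actual
// driver lives in SparseTensorPasses.cpp; names here are illustrative):
//
//   SparseTensorTypeToPtrConverter converter;
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   ConversionTarget target(*ctx);
//   ... // mark sparse-tensor ops illegal, runtime calls legal, etc.
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();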