//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
  return std::nullopt;
}
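
// For example, the type conversion above maps a CSR-annotated tensor type
// (with `#CSR` as a hypothetical encoding alias) such as
//   tensor<?x?xf64, #CSR>
// to the opaque runtime handle type
//   !llvm.ptr<i8>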

/// Replaces `op` with a `func::CallOp` to the function reference
/// returned by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
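
// Both lookups above lower to plain (non-CIface) runtime calls; a sketch
// of the generated IR, with illustrative SSA names:
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseDimSize(%t, %c1) : (!llvm.ptr<i8>, index) -> index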

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}
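
// For instance (a sketch; `#CSR` is a hypothetical alias), for an
// identity-mapped `tensor<8x?xf64, #CSR>` the lookup above folds level 0
// to `arith.constant 8 : index`, while level 1 emits a runtime
// `sparseLvlSize` query.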

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to
/// have a static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}
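
// E.g., for a dynamic size `%n` and element type `f64`, the allocation
// above emits (a sketch):
//   %buf = memref.alloc(%n) : memref<?xf64>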

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's
/// underlying buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer with the bare (aligned) pointers to the
/// given level buffers and value buffer, and returns it as a single
/// opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Pass in the lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Pass in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary type encodings.
    setTemplateTypes(stt);
    // Finally, make note that initialization is complete.
    assert(isInitialized() && "Initialization failed");
    // And return `this` for method chaining.
    return *this;
  }

  /// (Re)sets the C++ template type parameters, and returns `this`
  /// for method chaining. This is already done as part of `genBuffers`,
  /// but is factored out so that it can also be called independently
  /// whenever subsequent `genNewCall` calls want to reuse the same
  /// buffers but different type parameters.
  //
  // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
  // is there a better way to handle that than this one-off setter method?
  NewCallParams &setTemplateTypes(SparseTensorType stt) {
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }
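
  // A sketch of the call emitted by `genNewCall`, with the static
  // parameters in slot order (illustrative SSA names; the symbol is
  // resolved through its `_mlir_ciface_` wrapper):
  //   %h = call @newSparseTensor(%dimSizes, %lvlSizes, %lvlTypes, %dim2lvl,
  //                              %lvl2dim, %posTp, %crdTp, %valTp, %action,
  //                              %ptr) : (...) -> !llvm.ptr<i8>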

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
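
// Typical usage of `NewCallParams` (a sketch mirroring the conversion
// patterns below):
//   Value handle = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizes)
//                      .genNewCall(Action::kEmpty);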

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                          Value coo) {
  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
  createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing dimension-sizes.
class SparseTensorToDimSizeConverter
    : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse DimOp.
    if (!stt.hasEncoding())
      return failure();
    // Only rewrite DimOp with constant index.
    std::optional<int64_t> dim = op.getConstantIndex();
    if (!dim)
      return failure();
    // Generate the call.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(
        op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
    return success();
  }
};
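
// Example rewrite performed by the pattern above (a sketch; `#CSR` is a
// hypothetical alias):
//   %d = tensor.dim %t, %c1 : tensor<?x?xf64, #CSR>
// becomes, once %t has been converted to an opaque handle,
//   %d = call @sparseDimSize(%t, %c1) : (!llvm.ptr<i8>, index) -> index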

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the reader opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    Value lvlSizesBuffer =
        genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
                      dim2lvlBuffer, lvl2dimBuffer);
    // Use the `reader` to parse the file.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Type eltTp = stt.getElementType();
    Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
    SmallVector<Value, 8> params{
        reader,
        lvlSizesBuffer,
        genLvlTypesBuffer(rewriter, loc, stt),
        dim2lvlBuffer,
        lvl2dimBuffer,
        constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
        constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
        valTp};
    Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
                                  opaqueTp, params, EmitCInterface::On)
                       .getResult(0);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
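
// For example (a sketch; `#CSR` is a hypothetical alias), a file read
//   %t = sparse_tensor.new %src : !llvm.ptr<i8> to tensor<?x?xf64, #CSR>
// lowers to the reader-opening calls generated by `genReader`, a single
// `newSparseTensorFromReader` call, and a closing `delSparseTensorReader`.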

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
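
// For instance (a sketch; `#CSR` is a hypothetical alias), with one
// dynamic dimension,
//   %t = bufferization.alloc_tensor(%d0) : tensor<?x8xf64, #CSR>
// maps to a `newSparseTensor` call with `Action::kEmpty`, passing the
// dimension sizes {%d0, constant 8} through the buffers built above.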

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct the empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};
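
// E.g. (a sketch): `bufferization.dealloc_tensor %t` on a sparse tensor
// simply becomes
//   call @delSparseTensor(%t) : (!llvm.ptr<i8>) -> ()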

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};
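
// Example (a sketch; assuming an `index` position type, whose overhead
// suffix is `0`):
//   %p = sparse_tensor.positions %t {level = 1 : index}
// becomes
//   %p = call @sparsePositions0(%t, %c1)
//       : (!llvm.ptr<i8>, index) -> memref<?xindex>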

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};
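
// E.g. (a sketch, for `f64` values; `#CSR` is a hypothetical alias):
//   %v = sparse_tensor.values %t : tensor<?x?xf64, #CSR> to memref<?xf64>
// becomes
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>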

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array, which equals the number of
    // actually stored entries.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};
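
// E.g. (a sketch, for `f64` values): `sparse_tensor.number_of_entries %t`
// becomes
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr<i8>) -> memref<?xf64>
//   %n = memref.dim %v, %c0 : memref<?xf64>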

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    auto vref = genAllocaScalar(rewriter, loc, elemTp);
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
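
// A sketch of the IR generated above for a 2-d `f64` tensor (illustrative
// names):
//   %coords = memref.alloca() : memref<2xindex>
//   %vref   = memref.alloca() : memref<f64>
//   ... stores of the level coordinates and the value ...
//   call @lexInsertF64(%t, %coords, %vref)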

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    // Convert to default permuted COO.
    Value src = adaptor.getOperands()[0];
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    Value coo = NewCallParams(rewriter, loc)
                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
                    .genNewCall(Action::kToCOO, src);
    // Then output the tensor to an external file with coordinates in the
    // externally visible lexicographic coordinate order.  A sort is
    // required if the source was not in that order yet (note that the
    // sort can be dropped altogether if the external format does not care
    // about the order at all, but here we assume it does).
    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
    const Type elemTp = srcTp.getElementType();
    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
    genDelCOOCall(rewriter, loc, elemTp, coo);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
           SparseCastConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorOutConverter, SparseTensorAssembleConverter>(
          typeConverter, patterns.getContext());
}
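
// Typical driver usage (a sketch; the conversion target setup is elided,
// and `op`, `ctx`, and `target` are assumed from the surrounding pass):
//   SparseTensorTypeToPtrConverter converter;
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();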