//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
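/// For example, `tensor<?x?xf64, #CSR>` lowers to the opaque `!llvm.ptr`.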
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Replaces the `op` with a `func::CallOp` to the function reference
/// returned by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
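/// The emitted IR is roughly (a sketch, for orientation only):
///   %sz = call @sparseLvlSize(%tensor, %lvl) : (!llvm.ptr, index) -> index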
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare pointers of the given
/// level tensors and value tensor, returned as a single opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
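///
/// Example usage, as done by the conversion patterns below (a sketch):
///   Value tensor = NewCallParams(rewriter, loc)
///                      .genBuffers(stt, dimSizes)
///                      .genNewCall(Action::kEmpty);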
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;
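
  // The `kParam*` indices above follow the parameter order of the runtime
  // entry point, roughly (a sketch of the signature, for orientation only):
  //   newSparseTensor(dimSizes, lvlSizes, lvlTypes, dim2lvl, lvl2dim,
  //                   posTp, crdTp, valTp, action, ptr)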

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
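/// For example (a sketch, with types elided):
///   %sz = sparse_tensor.lvl %t, %c1
/// lowers to a `sparseLvlSize` runtime call on the opaque tensor pointer.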
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
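/// For example, a cast that merely refines or relaxes the static shape,
/// such as (a sketch, with `#CSR` an assumed encoding alias):
///   %0 = tensor.cast %t : tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR>
/// is a no-op on the underlying opaque pointer and simply folds away.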
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

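/// Sparse conversion rule for reinterpret-map operations: with opaque
/// pointers, the remapping is a no-op, so the operation is simply folded.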
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
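/// The rewrite frees the underlying storage via the runtime (a sketch):
///   call @delSparseTensor(%ptr) : (!llvm.ptr) -> ()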
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
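/// For 64-bit positions, for instance, the emitted call is roughly
/// (a sketch that elides the generated C-interface wrapper):
///   %0 = call @sparsePositions64(%ptr, %lvl)
///       : (!llvm.ptr, index) -> memref<?xi64>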
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
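/// The count is obtained as the runtime size of the values array, i.e.,
/// the rewrite emits roughly (a sketch for f64 elements):
///   %v = call @sparseValuesF64(%ptr) : (!llvm.ptr) -> memref<?xf64>
///   %n = memref.dim %v, %c0 : memref<?xf64>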
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array, which equals the number of
    // actually stored values.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
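/// The level coordinates and the value are passed to the runtime through
/// small stack-allocated buffers, roughly (a sketch for f64 elements):
///   call @lexInsertF64(%ptr, %lvlCoords, %vref) ...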
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
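///
/// Typical driver usage (a sketch; the in-tree driver is the sparse tensor
/// conversion pass, and `ctx` and `module` are assumed here):
///   SparseTensorTypeToPtrConverter converter;
///   ConversionTarget target(*ctx);
///   RewritePatternSet patterns(ctx);
///   populateSparseTensorConversionPatterns(converter, patterns);
///   (void)applyPartialConversion(module, target, std::move(patterns));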
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
}
757