//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//
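
// For illustration only (a sketch, not drawn from this file's tests):
// after conversion, a function that takes a sparse tensor, such as
//
//   func.func @f(%arg0: tensor<?x?xf64, #CSR>) -> index
//
// instead takes an opaque pointer into the runtime,
//
//   func.func @f(%arg0: !llvm.ptr) -> index
//
// and all queries against the tensor become calls into the support library.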

#include "CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Replaces `op` with a `func::CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
                    emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
                                                   operands);
}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
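
// A sketch of the IR emitted by `genLvlSizeCall` (the tensor operand has
// already been converted to an opaque pointer at this point):
//
//   %c1 = arith.constant 1 : index
//   %sz = call @sparseLvlSize(%tensor, %c1) : (!llvm.ptr, index) -> index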

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There is no need to assert `isPermutation` here, because
  // `getDimPosition` already checks that the expression is an
  // `AffineDimExpr`, which is all we care about (for supporting
  // permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}
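
// For example (a sketch): for a tensor typed tensor<8x?xf64, #CSR> with an
// identity dim-to-lvl map, querying level 0 folds to `arith.constant 8 :
// index`, whereas querying level 1 emits a `sparseLvlSize` call as above.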

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  if (const auto sz = stt.getStaticDimSize(dim))
    return constantIndex(builder, loc, *sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}
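
// A sketch of the emitted IR, for a dynamic size `%sz` and element type f64:
//
//   %buf = memref.alloc(%sz) : memref<?xf64>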

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto dlt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer to the buffer that backs the
/// given tensor.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a buffer of bare pointers to the given level tensors and
/// value tensor, and returns the address of that buffer as an opaque
/// pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Pass in the level buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Pass in the value buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}
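
// A sketch of the IR produced here for one level tensor plus a value tensor
// (buffer names are illustrative):
//
//   %p0 = memref.extract_aligned_pointer_as_index %lvl0 : memref<?xindex> -> index
//   %p1 = memref.extract_aligned_pointer_as_index %vals : memref<?xf64> -> index
//   ... %p0 and %p1 are stored into a stack buffer %b ...
//   %pb = memref.extract_aligned_pointer_as_index %b : memref<2xindex> -> index
//   %i  = arith.index_cast %pb : index to i64
//   %ptr = llvm.inttoptr %i : i64 to !llvm.ptr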

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "Swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    params[kParamLvlSizes] =
        genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
                      params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
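
// Typical usage within the conversion patterns below (a sketch; `stt` and
// `dimSizes` are assumed to be in scope):
//
//   Value tensor = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizes)
//                      .genNewCall(Action::kEmpty);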

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
                           ValueRange ptr) {
  SmallString<15> name{"sparseValues",
                       primaryTypeFunctionSuffix(tp.getElementType())};
  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};
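
// A sketch of this rewrite (the level index must be a constant):
//
//   %0 = sparse_tensor.lvl %t, %c1 : tensor<?x?xf64, #CSR>
//
// becomes
//
//   %0 = call @sparseLvlSize(%t, %c1) : (!llvm.ptr, index) -> index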

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the reinterpret_map operator.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
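
// A sketch of this rewrite: a materialization from file such as
//
//   %t = sparse_tensor.new %src : !llvm.ptr to tensor<?x?xf64, #CSR>
//
// becomes, roughly,
//
//   %r = ... open a file reader (see `genReader` in CodegenUtils) ...
//   %t = call @newSparseTensor(...)  // with Action::kFromReader and %r
//   call @delSparseTensorReader(%r) : (!llvm.ptr) -> ()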

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op,
                                         "sparse tensor copy not implemented");
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizes;
    dimSizes.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; ++d) {
      dimSizes.push_back(stt.isDynamicDim(d)
                             ? adaptor.getOperands()[operandCtr++]
                             : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizes)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
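
// A sketch of this rewrite for a tensor with one dynamic dimension:
//
//   %t = tensor.empty(%d0) : tensor<?x8xf64, #CSR>
//
// becomes, after materializing the size, level-type, and map buffers,
//
//   %t = call @newSparseTensor(...)  // with Action::kEmpty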

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resTp = op.getType();
    Type posTp = cast<ShapedType>(resTp).getElementType();
    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
                          EmitCInterface::On);
    return success();
  }
};
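
// A sketch of this rewrite (assuming 64-bit positions, so the runtime
// suffix resolves to "64"):
//
//   %p = sparse_tensor.positions %t { level = 1 : index }
//          : tensor<?x?xf64, #CSR> to memref<?xi64>
//
// becomes, roughly,
//
//   %c1 = arith.constant 1 : index
//   %p = call @sparsePositions64(%t, %c1) : (!llvm.ptr, index) -> memref<?xi64>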

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: use `SparseTensorType::getCrdType` instead.
    Type resType = op.getType();
    const Type crdTp = cast<ShapedType>(resType).getElementType();
    SmallString<19> name{"sparseCoordinates",
                         overheadTypeFunctionSuffix(crdTp)};
    Location loc = op->getLoc();
    Value lvl = constantIndex(rewriter, loc, op.getLevel());

    // The function returns a MemRef without a layout.
    MemRefType callRetType = get1DMemRefType(crdTp, false);
    SmallVector<Value> operands{adaptor.getTensor(), lvl};
    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
                      operands, EmitCInterface::On);
    Value callRet =
        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
            .getResult(0);

    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (resType != callRetType)
      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
    rewriter.replaceOp(op, callRet);

    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resType = cast<ShapedType>(op.getType());
    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
                                         adaptor.getOperands()));
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Query the size of the values array, which yields the number of
    // actually stored entries.
    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
                                               constantIndex(rewriter, loc, 0));
    return success();
  }
};
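
// A sketch of this rewrite (for f64 values):
//
//   %n = sparse_tensor.number_of_entries %t : tensor<?x?xf64, #CSR>
//
// becomes
//
//   %v = call @sparseValuesF64(%t) : (!llvm.ptr) -> memref<?xf64>
//   %n = memref.dim %v, %c0 : memref<?xf64>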

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getTensor());
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Find the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoist the allocas outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
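
// A sketch of the emitted IR (for f64 values; the allocas are hoisted out
// of the enclosing loop nest):
//
//   %coords = memref.alloca(...) : memref<?xindex>
//   %vref   = memref.alloca() : memref<f64>
//   ... store the level coordinates and the inserted value ...
//   call @lexInsertF64(%t, %coords, %vref)
//       : (!llvm.ptr, memref<?xindex>, memref<f64>) -> ()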

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry to the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // AssembleOp always returns a statically shaped tensor result.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
}
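
// A sketch of how a pass might drive these patterns (details elided; this
// mirrors the usual MLIR dialect-conversion setup rather than quoting the
// actual SparseTensorConversionPass verbatim):
//
//   SparseTensorTypeToPtrConverter converter;
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   ConversionTarget target(*ctx);
//   // ...mark the involved dialects/ops legal or illegal...
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();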
756