1 //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor primitives into calls into a runtime
10 // support library. Sparse tensor types are converted into opaque pointers
11 // to the underlying sparse storage schemes. The use of opaque pointers
12 // together with a runtime support library keeps the conversion relatively
13 // simple, but at the expense of IR opacity, which obscures opportunities
14 // for subsequent optimization of the IR. An alternative is provided by
15 // the SparseTensorCodegen pass.
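//
// For example, after this pass a value of a sparse tensor type (say, a CSR
// matrix of f64 values) is represented by an opaque !llvm.ptr handle, and
// queries such as sparse_tensor.lvl become calls into the runtime support
// library (e.g. sparseLvlSize). The IR snippets sketched in comments below
// are illustrative rather than exhaustive.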
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "CodegenUtils.h"
20 
21 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
22 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23 #include "mlir/Dialect/Linalg/Utils/Utils.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
27 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30 #include "mlir/Dialect/Tensor/IR/Tensor.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 
33 using namespace mlir;
34 using namespace mlir::sparse_tensor;
35 
36 namespace {
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Maps each sparse tensor type to an opaque pointer.
43 static std::optional<Type> convertSparseTensorTypes(Type type) {
44   if (getSparseTensorEncoding(type) != nullptr)
45     return LLVM::LLVMPointerType::get(type.getContext());
46   return std::nullopt;
47 }
48 
49 /// Generates a call to look up a level-size.  N.B., this only generates
50 /// the raw function call, and therefore (intentionally) does not perform
51 /// any dim<->lvl conversion or other logic.
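///
/// A rough sketch of the generated IR, e.g. for `lvl == 1`:
///   %c1 = arith.constant 1 : index
///   %sz = call @sparseLvlSize(%tensor, %c1) : (!llvm.ptr, index) -> index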
52 static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
53                             uint64_t lvl) {
54   StringRef name = "sparseLvlSize";
55   SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
56   Type iTp = builder.getIndexType();
57   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
58       .getResult(0);
59 }
60 
61 /// Generates a call to look up a dimension-size.  N.B., this only generates
62 /// the raw function call, and therefore (intentionally) does not perform
63 /// any dim<->lvl conversion or other logic.
64 static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
65                             uint64_t dim) {
66   StringRef name = "sparseDimSize";
67   SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
68   Type iTp = builder.getIndexType();
69   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
70       .getResult(0);
71 }
72 
73 /// Looks up a level-size by returning a statically-computed constant
74 /// (when possible), or by calling `genLvlSizeCall` (when dynamic).
75 static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
76                                  SparseTensorType stt, Value tensor,
77                                  Level lvl) {
78   // Only sparse tensors have "levels" to query.
79   assert(stt.hasEncoding());
80   // TODO: The following implementation only handles permutations;
81   // we'll need to generalize this to handle arbitrary AffineExpr.
82   //
83 // There is no need to assert `isPermutation` here, because
84 // `getDimPosition` already checks that the expr isa `AffineDimExpr`,
85 // which is all we care about (for supporting permutations).
86   const Dimension dim =
87       stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
88   const Size sz = stt.getDynamicDimSize(dim);
89   if (!ShapedType::isDynamic(sz))
90     return constantIndex(builder, loc, sz);
91   // If we cannot statically compute the size from the shape, then we
92   // must dynamically query it.  (In principle we could also dynamically
93   // compute it, but since we already did so to construct the `tensor`
94   // in the first place, we might as well query rather than recompute.)
95   return genLvlSizeCall(builder, loc, tensor, lvl);
96 }
97 
98 /// Looks up a dimension-size by returning a constant from the shape
99 /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
100 /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
101 /// of dense tensors).
102 static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
103                                  SparseTensorType stt, Value tensor,
104                                  Dimension dim) {
105   const Size sz = stt.getDynamicDimSize(dim);
106   if (!ShapedType::isDynamic(sz))
107     return constantIndex(builder, loc, sz);
108   if (stt.hasEncoding())
109     return genDimSizeCall(builder, loc, tensor, dim);
110   return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
111 }
112 
113 /// Populates the array with the dimension-sizes of the given tensor.
114 static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
115                          Value tensor, SmallVectorImpl<Value> &out) {
116   const Dimension dimRank = stt.getDimRank();
117   out.clear();
118   out.reserve(dimRank);
119   for (Dimension d = 0; d < dimRank; d++)
120     out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
121 }
122 
123 /// Returns an array with the dimension-sizes of the given tensor.
124 /// If the *tensor* parameter is null, the tensor type is assumed to have a
125 /// static shape.
126 static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
127                                       SparseTensorType stt,
128                                       Value tensor = Value()) {
129   SmallVector<Value> out;
130   fillDimSizes(builder, loc, stt, tensor, out);
131   return out;
132 }
133 
134 /// Generates an uninitialized buffer of the given size and type,
135 /// but returns it as type `memref<? x $tp>` (rather than as type
136 /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
137 /// this buffer must be explicitly deallocated by the client.
138 static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
139   auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
140   return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
141 }
142 
143 /// Generates a temporary buffer for the level-types of the given encoding.
144 static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
145                                SparseTensorType stt) {
146   SmallVector<Value> lvlTypes;
147   lvlTypes.reserve(stt.getLvlRank());
148   for (const auto dlt : stt.getEncoding().getLvlTypes())
149     lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
150   return allocaBuffer(builder, loc, lvlTypes);
151 }
152 
153 /// Extracts the bare (aligned) pointer to the given tensor's buffer.
154 static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
155                                       Value tensor) {
156   auto buf = genToMemref(builder, loc, tensor);
157   return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
158 }
159 
160 /// Generates a buffer of the bare pointers to the given level and value buffers.
161 static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
162                                ValueRange lvlTensors, Value valTensor) {
163   SmallVector<Value> lvlBarePtrs;
164   lvlBarePtrs.reserve(lvlTensors.size() + 1);
165   // Passing in lvl buffer pointers.
166   for (const auto lvl : lvlTensors)
167     lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
168 
169   // Passing in value buffer pointers.
170   lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171   Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
172       loc, allocaBuffer(builder, loc, lvlBarePtrs));
173   Value idxCast =
174       builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
175   return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
176                                           idxCast);
177 }
178 
179 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
180 /// the "swiss army knife" method of the sparse runtime support library
181 /// for materializing sparse tensors into the computation. This abstraction
182 /// reduces the need for modifications when the API changes.
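///
/// Typical usage (mirroring the conversion patterns below) chains the calls:
///   Value tensor = NewCallParams(rewriter, loc)
///                      .genBuffers(stt, dimSizes)
///                      .genNewCall(Action::kEmpty);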
183 class NewCallParams final {
184 public:
185   /// Allocates the `ValueRange` for the `func::CallOp` parameters.
186   NewCallParams(OpBuilder &builder, Location loc)
187       : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
188 
189   /// Initializes all static parameters (i.e., those which indicate
190   /// type-level information such as the encoding and sizes), generating
191   /// MLIR buffers as needed, and returning `this` for method chaining.
192   NewCallParams &genBuffers(SparseTensorType stt,
193                             ArrayRef<Value> dimSizesValues,
194                             Value dimSizesBuffer = Value()) {
195     assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
196     // Sparsity annotations.
197     params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
198     // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
199     params[kParamDimSizes] = dimSizesBuffer
200                                  ? dimSizesBuffer
201                                  : allocaBuffer(builder, loc, dimSizesValues);
202     params[kParamLvlSizes] =
203         genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
204                       params[kParamDim2Lvl], params[kParamLvl2Dim]);
205     // Secondary and primary types encoding.
206     const auto enc = stt.getEncoding();
207     params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
208     params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
209     params[kParamValTp] =
210         constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
211     // Return `this` for method chaining.
212     return *this;
213   }
214 
215   /// Checks whether all the static parameters have been initialized.
216   bool isInitialized() const {
217     for (unsigned i = 0; i < kNumStaticParams; ++i)
218       if (!params[i])
219         return false;
220     return true;
221   }
222 
223   /// Generates a function call, with the current static parameters
224   /// and the given dynamic arguments.
225   Value genNewCall(Action action, Value ptr = Value()) {
226     assert(isInitialized() && "Must initialize before genNewCall");
227     StringRef name = "newSparseTensor";
228     params[kParamAction] = constantAction(builder, loc, action);
229     params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
230     return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
231         .getResult(0);
232   }
233 
234 private:
235   static constexpr unsigned kNumStaticParams = 8;
236   static constexpr unsigned kNumDynamicParams = 2;
237   static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
238   static constexpr unsigned kParamDimSizes = 0;
239   static constexpr unsigned kParamLvlSizes = 1;
240   static constexpr unsigned kParamLvlTypes = 2;
241   static constexpr unsigned kParamDim2Lvl = 3;
242   static constexpr unsigned kParamLvl2Dim = 4;
243   static constexpr unsigned kParamPosTp = 5;
244   static constexpr unsigned kParamCrdTp = 6;
245   static constexpr unsigned kParamValTp = 7;
246   static constexpr unsigned kParamAction = 8;
247   static constexpr unsigned kParamPtr = 9;
248 
249   OpBuilder &builder;
250   Location loc;
251   Type pTp;
252   Value params[kNumParams];
253 };
254 
255 /// Generates a call to obtain the values array.
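/// For f64 elements this emits, roughly:
///   %vals = call @sparseValuesF64(%ptr) : (!llvm.ptr) -> memref<?xf64>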
256 static Value genValuesCall(OpBuilder &builder, Location loc,
257                            SparseTensorType stt, Value ptr) {
258   auto eltTp = stt.getElementType();
259   auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
260   SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
261   return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
262       .getResult(0);
263 }
264 
265 /// Generates a call to obtain the positions array.
266 static Value genPositionsCall(OpBuilder &builder, Location loc,
267                               SparseTensorType stt, Value ptr, Level l) {
268   Type posTp = stt.getPosType();
269   auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
270   Value lvl = constantIndex(builder, loc, l);
271   SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
272   return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
273                         EmitCInterface::On)
274       .getResult(0);
275 }
276 
277 /// Generates a call to obtain the coordinates array.
278 static Value genCoordinatesCall(OpBuilder &builder, Location loc,
279                                 SparseTensorType stt, Value ptr, Level l) {
280   Type crdTp = stt.getCrdType();
281   auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
282   Value lvl = constantIndex(builder, loc, l);
283   SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
284   return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
285                         EmitCInterface::On)
286       .getResult(0);
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // Conversion rules.
291 //===----------------------------------------------------------------------===//
292 
293 /// Sparse conversion rule for returns.
294 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
295 public:
296   using OpConversionPattern::OpConversionPattern;
297   LogicalResult
298   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
299                   ConversionPatternRewriter &rewriter) const override {
300     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
301     return success();
302   }
303 };
304 
305 /// Sparse conversion rule for accessing level-sizes.
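/// For example, `sparse_tensor.lvl %t, %c1` on a sparse operand with a
/// constant level is rewritten into a `sparseLvlSize` runtime call
/// (see `genLvlSizeCall` above).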
306 class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
307 public:
308   using OpConversionPattern::OpConversionPattern;
309   LogicalResult
310   matchAndRewrite(LvlOp op, OpAdaptor adaptor,
311                   ConversionPatternRewriter &rewriter) const override {
312     const auto stt = getSparseTensorType(op.getSource());
313     // Only rewrite LvlOp on sparse tensors.
314     if (!stt.hasEncoding())
315       return failure();
316 
317     // Only rewrite LvlOp with a constant level index.
318     std::optional<int64_t> lvl = op.getConstantLvlIndex();
319 
320     if (!lvl)
321       return failure();
322 
323     // By now, if the level size is constant, the operation should have already
324     // been folded by LvlOp's folder, so we generate the call unconditionally.
325     Value src = adaptor.getOperands()[0];
326     rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
327     return success();
328   }
329 };
330 
331 /// Sparse conversion rule for trivial tensor casts.
332 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
333 public:
334   using OpConversionPattern::OpConversionPattern;
335   LogicalResult
336   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
337                   ConversionPatternRewriter &rewriter) const override {
338     // Only rewrite identically annotated source/dest.
339     auto encDst = getSparseTensorEncoding(op.getType());
340     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
341     if (!encDst || encDst != encSrc)
342       return failure();
343     rewriter.replaceOp(op, adaptor.getOperands());
344     return success();
345   }
346 };
347 
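/// Sparse conversion rule for the reinterpret_map operator.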
348 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
349 public:
350   using OpConversionPattern::OpConversionPattern;
351   LogicalResult
352   matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
353                   ConversionPatternRewriter &rewriter) const override {
354     // Simply fold the operation.
355     rewriter.replaceOp(op, adaptor.getSource());
356     return success();
357   }
358 };
359 
360 /// Sparse conversion rule for the new operator.
361 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
362 public:
363   using OpConversionPattern::OpConversionPattern;
364   LogicalResult
365   matchAndRewrite(NewOp op, OpAdaptor adaptor,
366                   ConversionPatternRewriter &rewriter) const override {
367     Location loc = op.getLoc();
368     const auto stt = getSparseTensorType(op);
369     if (!stt.hasEncoding())
370       return failure();
371     // Construct the `reader` opening method calls.
372     SmallVector<Value> dimShapesValues;
373     Value dimSizesBuffer;
374     Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
375                              dimShapesValues, dimSizesBuffer);
376     // Use the `reader` to parse the file.
377     Value tensor = NewCallParams(rewriter, loc)
378                        .genBuffers(stt, dimShapesValues, dimSizesBuffer)
379                        .genNewCall(Action::kFromReader, reader);
380     // Free the memory for `reader`.
381     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
382                    EmitCInterface::Off);
383     rewriter.replaceOp(op, tensor);
384     return success();
385   }
386 };
387 
388 /// Sparse conversion rule for the alloc operator.
389 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
390 class SparseTensorAllocConverter
391     : public OpConversionPattern<bufferization::AllocTensorOp> {
392 public:
393   using OpConversionPattern::OpConversionPattern;
394   LogicalResult
395   matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
396                   ConversionPatternRewriter &rewriter) const override {
397     const auto stt = getSparseTensorType(op);
398     if (!stt.hasEncoding())
399       return failure();
400     if (op.getCopy())
401       return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
402     // Gather all dimension sizes as SSA values.
403     Location loc = op.getLoc();
404     const Dimension dimRank = stt.getDimRank();
405     SmallVector<Value> dimSizes;
406     dimSizes.reserve(dimRank);
407     unsigned operandCtr = 0;
408     for (Dimension d = 0; d < dimRank; d++) {
409       dimSizes.push_back(
410           stt.isDynamicDim(d)
411               ? adaptor.getOperands()[operandCtr++]
412               : constantIndex(rewriter, loc, op.getStaticSize(d)));
413     }
414     // Generate the call to construct empty tensor. The sizes are
415     // explicitly defined by the arguments to the alloc operator.
416     rewriter.replaceOp(op, NewCallParams(rewriter, loc)
417                                .genBuffers(stt, dimSizes)
418                                .genNewCall(Action::kEmpty));
419     return success();
420   }
421 };
422 
423 /// Sparse conversion rule for the empty tensor.
424 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
425 public:
426   using OpConversionPattern::OpConversionPattern;
427   LogicalResult
428   matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
429                   ConversionPatternRewriter &rewriter) const override {
430     Location loc = op.getLoc();
431     const auto stt = getSparseTensorType(op);
432     if (!stt.hasEncoding())
433       return failure();
434     // Gather all dimension sizes as SSA values.
435     const Dimension dimRank = stt.getDimRank();
436     SmallVector<Value> dimSizes;
437     dimSizes.reserve(dimRank);
438     auto shape = op.getType().getShape();
439     unsigned operandCtr = 0;
440     for (Dimension d = 0; d < dimRank; d++) {
441       dimSizes.push_back(stt.isDynamicDim(d)
442                              ? adaptor.getOperands()[operandCtr++]
443                              : constantIndex(rewriter, loc, shape[d]));
444     }
445     // Generate the call to construct an empty tensor. The sizes are
446     // explicitly defined by the arguments to the empty operator.
447     rewriter.replaceOp(op, NewCallParams(rewriter, loc)
448                                .genBuffers(stt, dimSizes)
449                                .genNewCall(Action::kEmpty));
450     return success();
451   }
452 };
453 
454 /// Sparse conversion rule for the reorder_coo operator.
455 class SparseTensorReorderCOOConverter
456     : public OpConversionPattern<ReorderCOOOp> {
457 public:
458   using OpConversionPattern::OpConversionPattern;
459 
460   LogicalResult
461   matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
462                   ConversionPatternRewriter &rewriter) const override {
463     const Location loc = op->getLoc();
464     const auto srcTp = getSparseTensorType(op.getInputCoo());
465     const auto dstTp = getSparseTensorType(op);
466 
467     const Value src = adaptor.getInputCoo();
468 
469     NewCallParams params(rewriter, loc);
470     SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
471     rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
472                                .genNewCall(Action::kSortCOOInPlace, src));
473 
474     return success();
475   }
476 };
477 
478 /// Sparse conversion rule for the dealloc operator.
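/// For a sparse operand this simply emits, roughly:
///   call @delSparseTensor(%ptr) : (!llvm.ptr) -> ()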
479 class SparseTensorDeallocConverter
480     : public OpConversionPattern<bufferization::DeallocTensorOp> {
481 public:
482   using OpConversionPattern::OpConversionPattern;
483   LogicalResult
484   matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
485                   ConversionPatternRewriter &rewriter) const override {
486     if (!getSparseTensorType(op.getTensor()).hasEncoding())
487       return failure();
488     StringRef name = "delSparseTensor";
489     createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
490                    EmitCInterface::Off);
491     rewriter.eraseOp(op);
492     return success();
493   }
494 };
495 
496 /// Sparse conversion rule for position accesses.
497 class SparseTensorToPositionsConverter
498     : public OpConversionPattern<ToPositionsOp> {
499 public:
500   using OpConversionPattern::OpConversionPattern;
501   LogicalResult
502   matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
503                   ConversionPatternRewriter &rewriter) const override {
504     auto stt = getSparseTensorType(op.getTensor());
505     auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
506                                  adaptor.getTensor(), op.getLevel());
507     rewriter.replaceOp(op, poss);
508     return success();
509   }
510 };
511 
512 /// Sparse conversion rule for coordinate accesses.
513 class SparseTensorToCoordinatesConverter
514     : public OpConversionPattern<ToCoordinatesOp> {
515 public:
516   using OpConversionPattern::OpConversionPattern;
517   LogicalResult
518   matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
519                   ConversionPatternRewriter &rewriter) const override {
520     auto stt = getSparseTensorType(op.getTensor());
521     auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
522                                    adaptor.getTensor(), op.getLevel());
523     // Cast the memref to the type expected by its users; the two types
524     // should be compatible at runtime.
525     if (op.getType() != crds.getType())
526       crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
527     rewriter.replaceOp(op, crds);
528     return success();
529   }
530 };
531 
532 /// Sparse conversion rule for value accesses.
533 class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
534 public:
535   using OpConversionPattern::OpConversionPattern;
536   LogicalResult
537   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
538                   ConversionPatternRewriter &rewriter) const override {
539     auto stt = getSparseTensorType(op.getTensor());
540     auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
541     rewriter.replaceOp(op, vals);
542     return success();
543   }
544 };
545 
546 /// Sparse conversion rule for the number-of-entries operator.
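/// The number of stored entries is the length of the values array, e.g.
/// (a rough sketch, assuming f64 values):
///   %vals = call @sparseValuesF64(%t) : (!llvm.ptr) -> memref<?xf64>
///   %nse = memref.dim %vals, %c0 : memref<?xf64>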
547 class SparseNumberOfEntriesConverter
548     : public OpConversionPattern<NumberOfEntriesOp> {
549 public:
550   using OpConversionPattern::OpConversionPattern;
551   LogicalResult
552   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
553                   ConversionPatternRewriter &rewriter) const override {
554     // Query the values array size to obtain the number of stored entries.
555     auto stt = getSparseTensorType(op.getTensor());
556     auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
557     auto zero = constantIndex(rewriter, op.getLoc(), 0);
558     rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
559     return success();
560   }
561 };
562 
563 /// Sparse conversion rule for tensor rematerialization.
564 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
565 public:
566   using OpConversionPattern::OpConversionPattern;
567   LogicalResult
568   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
569                   ConversionPatternRewriter &rewriter) const override {
570     if (op.getHasInserts()) {
571       // Finalize any pending insertions.
572       StringRef name = "endLexInsert";
573       createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
574                      EmitCInterface::Off);
575     }
576     rewriter.replaceOp(op, adaptor.getOperands());
577     return success();
578   }
579 };
580 
581 /// Sparse conversion rule for the insertion operator.
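/// A rough sketch of the emitted code, assuming f64 values (the level
/// coordinates and the value are first stored into stack-allocated memrefs):
///   call @lexInsertF64(%tensor, %lvlCoords, %vref) : (...) -> ()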
582 class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
583 public:
584   using OpConversionPattern::OpConversionPattern;
585   LogicalResult
586   matchAndRewrite(InsertOp op, OpAdaptor adaptor,
587                   ConversionPatternRewriter &rewriter) const override {
588     // Note that the current regime only allows for strict lexicographic
589     // coordinate order. All values are passed by reference through stack
590     // allocated memrefs.
591     Location loc = op->getLoc();
592     const auto stt = getSparseTensorType(op.getTensor());
593     const auto elemTp = stt.getElementType();
594     const Level lvlRank = stt.getLvlRank();
595     Value lvlCoords, vref;
596     {
597       OpBuilder::InsertionGuard guard(rewriter);
598       Operation *loop = op;
599       // Finds the outermost loop.
600       while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
601         loop = l;
602 
603       if (llvm::isa<LoopLikeOpInterface>(loop)) {
604         // Hoists alloca outside the loop to avoid stack overflow.
605         rewriter.setInsertionPoint(loop);
606       }
607       lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
608       vref = genAllocaScalar(rewriter, loc, elemTp);
609     }
610     storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
611     rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
612     SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
613     createFuncCall(rewriter, loc, name, {},
614                    {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
615     rewriter.replaceOp(op, adaptor.getTensor());
616     return success();
617   }
618 };
619 
620 /// Sparse conversion rule for the expand operator.
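/// The op is lowered to three heap-allocated buffers of the innermost level
/// size (values, filled-switch, added coordinates), two linalg.fill ops that
/// zero the values and filled-switch buffers, and a zero initial count; the
/// compress op below consumes and eventually deallocates these buffers.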
621 class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
622 public:
623   using OpConversionPattern::OpConversionPattern;
624   LogicalResult
625   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
626                   ConversionPatternRewriter &rewriter) const override {
627     Location loc = op->getLoc();
628     const auto srcTp = getSparseTensorType(op.getTensor());
629     Type eltType = srcTp.getElementType();
630     Type boolType = rewriter.getIntegerType(1);
631     Type idxType = rewriter.getIndexType();
632     // All initialization should be done on entry of the loop nest.
633     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
634     // Get the cardinality of valid coordinates for the innermost level.
635     Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
636                                    srcTp.getLvlRank() - 1);
637     // Allocate temporary buffers for values, filled-switch, and coordinates.
638     // We do not use stack buffers for this, since the expanded size may
639     // be rather large (as it envelops a single expanded dense dimension).
640     Value values = genAlloc(rewriter, loc, sz, eltType);
641     Value filled = genAlloc(rewriter, loc, sz, boolType);
642     Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
643     Value zero = constantZero(rewriter, loc, idxType);
644     // Reset the values/filled-switch to all-zero/false. Note that this
645     // introduces an O(N) operation into the computation, but this reset
646     // operation is amortized over the innermost loops for the access
647     // pattern expansion. As noted in the operation doc, we would like
648     // to amortize this setup cost even between kernels.
649     rewriter.create<linalg::FillOp>(
650         loc, ValueRange{constantZero(rewriter, loc, eltType)},
651         ValueRange{values});
652     rewriter.create<linalg::FillOp>(
653         loc, ValueRange{constantZero(rewriter, loc, boolType)},
654         ValueRange{filled});
655     // Replace the expansion op with these buffers and the initial count (zero).
656     assert(op.getNumResults() == 4);
657     rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
658     return success();
659   }
660 };
661 
662 /// Sparse conversion rule for the compress operator.
663 class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
664 public:
665   using OpConversionPattern::OpConversionPattern;
666   LogicalResult
667   matchAndRewrite(CompressOp op, OpAdaptor adaptor,
668                   ConversionPatternRewriter &rewriter) const override {
669     Location loc = op->getLoc();
670     // Note that this method call resets the values/filled-switch back to
671     // all-zero/false by only iterating over the set elements, so the
672     // complexity remains proportional to the sparsity of the expanded
673     // access pattern.
674     Value values = adaptor.getValues();
675     Value filled = adaptor.getFilled();
676     Value added = adaptor.getAdded();
677     Value count = adaptor.getCount();
678     Value tensor = adaptor.getTensor();
679     const auto stt = getSparseTensorType(op.getTensor());
680     const Type elemTp = stt.getElementType();
681     const Level lvlRank = stt.getLvlRank();
682     auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
683     storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
684     SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
685     createFuncCall(rewriter, loc, name, {},
686                    {tensor, lvlCoords, values, filled, added, count},
687                    EmitCInterface::On);
688     rewriter.replaceOp(op, adaptor.getTensor());
689     // Deallocate the buffers on exit of the loop nest.
690     Operation *parent = getTop(op);
691     rewriter.setInsertionPointAfter(parent);
692     rewriter.create<memref::DeallocOp>(loc, values);
693     rewriter.create<memref::DeallocOp>(loc, filled);
694     rewriter.create<memref::DeallocOp>(loc, added);
695     return success();
696   }
697 };
698 
699 /// Sparse conversion rule for the sparse_tensor.assemble operator.
700 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
701 public:
702   using OpConversionPattern::OpConversionPattern;
703   LogicalResult
704   matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
705                   ConversionPatternRewriter &rewriter) const override {
706     const Location loc = op->getLoc();
707     const auto dstTp = getSparseTensorType(op.getResult());
708     assert(dstTp.hasStaticDimShape());
709     SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
710     // Use a library method to transfer the external buffers from
711     // clients to the internal SparseTensorStorage. Since we cannot
712     // assume clients transfer ownership of the buffers, this method
713     // will copy all data over into a new SparseTensorStorage.
714     Value dst =
715         NewCallParams(rewriter, loc)
716             .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
717             .genNewCall(Action::kPack,
718                         genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
719                                           adaptor.getValues()));
720     rewriter.replaceOp(op, dst);
721     return success();
722   }
723 };
724 
725 /// Sparse conversion rule for the sparse_tensor.disassemble operator.
726 class SparseTensorDisassembleConverter
727     : public OpConversionPattern<DisassembleOp> {
728 public:
729   using OpConversionPattern::OpConversionPattern;
730   LogicalResult
731   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
732                   ConversionPatternRewriter &rewriter) const override {
733     // We simply expose the buffers to the external client. This
734     // assumes the client only reads the buffers (usually copying them
735     // to external data structures, such as numpy arrays).
736     Location loc = op->getLoc();
737     auto stt = getSparseTensorType(op.getTensor());
738     SmallVector<Value> retVal;
739     SmallVector<Value> retLen;
740     // Get the values buffer first.
741     auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
742     auto valLenTp = op.getValLen().getType();
743     auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
744     retVal.push_back(vals);
745     retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
746     // Then get the positions and coordinates buffers.
747     const Level lvlRank = stt.getLvlRank();
748     Level trailCOOLen = 0;
749     for (Level l = 0; l < lvlRank; l++) {
750       if (!stt.isUniqueLvl(l) &&
751           (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
752         // A `(loose)compressed_nu` level marks the start of the trailing
753         // COO region. Since the target coordinate buffer used for trailing
754         // COO is passed in as an AoS scheme while SparseTensorStorage uses
755         // an SoA scheme, we cannot simply use the internal buffers.
756         trailCOOLen = lvlRank - l;
757         break;
758       }
759       if (stt.isWithPos(l)) {
760         auto poss =
761             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
762         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
763         auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
764         retVal.push_back(poss);
765         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
766       }
767       if (stt.isWithCrd(l)) {
768         auto crds =
769             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
770         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
771         auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
772         retVal.push_back(crds);
773         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
774       }
775     }
776     // Handle AoS vs. SoA mismatch for COO.
777     if (trailCOOLen != 0) {
778       uint64_t cooStartLvl = lvlRank - trailCOOLen;
779       assert(!stt.isUniqueLvl(cooStartLvl) &&
780              (stt.isCompressedLvl(cooStartLvl) ||
781               stt.isLooseCompressedLvl(cooStartLvl)));
782       // Positions.
783       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
784                                    cooStartLvl);
785       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
786       auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
787       retVal.push_back(poss);
788       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
789       // Coordinates, copied over with:
790       //    for (i = 0; i < crdLen; i++)
791       //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
792       auto buf =
793           genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
794       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
795                                       cooStartLvl);
796       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
797                                       cooStartLvl + 1);
798       auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
799       auto two = constantIndex(rewriter, loc, 2);
800       auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
801       Type indexType = rewriter.getIndexType();
802       auto zero = constantZero(rewriter, loc, indexType);
803       auto one = constantOne(rewriter, loc, indexType);
804       scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
805       auto idx = forOp.getInductionVar();
806       rewriter.setInsertionPointToStart(forOp.getBody());
807       auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
808       auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
809       SmallVector<Value> args;
810       args.push_back(idx);
811       args.push_back(zero);
812       rewriter.create<memref::StoreOp>(loc, c0, buf, args);
813       args[1] = one;
814       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
815       rewriter.setInsertionPointAfter(forOp);
816       auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
817       retVal.push_back(buf);
818       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
819     }
820     // Converts MemRefs back to Tensors.
821     assert(retVal.size() + retLen.size() == op.getNumResults());
822     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
823       auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
824       retVal[i] =
825           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
826     }
827     // Appends the actual memory length used in each buffer returned.
828     retVal.append(retLen.begin(), retLen.end());
829     rewriter.replaceOp(op, retVal);
830     return success();
831   }
832 };
833 
834 } // namespace
835 
836 //===----------------------------------------------------------------------===//
837 // Sparse tensor type conversion into opaque pointer.
838 //===----------------------------------------------------------------------===//
839 
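/// With this converter, every type that carries a sparse tensor encoding
/// (and only such types) is rewritten into an opaque !llvm.ptr; all other
/// types are passed through unchanged.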
840 mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
841   addConversion([](Type type) { return type; });
842   addConversion(convertSparseTensorTypes);
843 }
844 
845 //===----------------------------------------------------------------------===//
846 // Public method for populating conversion rules.
847 //===----------------------------------------------------------------------===//
848 
849 /// Populates the given patterns list with conversion rules required for
850 /// the sparsification of linear algebra operations.
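///
/// A minimal sketch of how the entry points fit together (the actual pass
/// configures the ConversionTarget with the appropriate legality rules,
/// which are omitted here):
///   SparseTensorTypeToPtrConverter converter;
///   ConversionTarget target(*ctx);
///   RewritePatternSet patterns(ctx);
///   populateSparseTensorConversionPatterns(converter, patterns);
///   (void)applyPartialConversion(module, target, std::move(patterns));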
851 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
852                                                   RewritePatternSet &patterns) {
853   patterns
854       .add<SparseReturnConverter, SparseTensorLvlOpConverter,
855            SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
856            SparseTensorAllocConverter, SparseTensorEmptyConverter,
857            SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
858            SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
859            SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
860            SparseTensorLoadConverter, SparseTensorInsertConverter,
861            SparseTensorExpandConverter, SparseTensorCompressConverter,
862            SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
863           typeConverter, patterns.getContext());
864 }
865