//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls to a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the cost of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}
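
// For example, under this type conversion any tensor type carrying a sparse
// encoding, e.g. `tensor<?x?xf64, #CSR>` (where `#CSR` is some sparse encoding
// alias), becomes the opaque `!llvm.ptr` handle managed by the runtime support
// library, while dense tensor types are left untouched.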

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}
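
// Roughly, the two helpers above emit IR along these lines (illustrative,
// shown for level/dimension 1 of a sparse operand `%t`):
//
//   %c1 = arith.constant 1 : index
//   %n  = call @sparseLvlSize(%t, %c1) : (!llvm.ptr, index) -> index
//
// Because `EmitCInterface::Off` is used, the plain runtime entry points are
// called directly, without the `_mlir_ciface_` wrapper prefix.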

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here, because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer that points to the tensor's buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Generates a temporary buffer holding the bare (aligned) pointers of the
/// given level tensors and the values tensor, and returns the address of
/// that buffer as an opaque pointer.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Pass in the level-buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Pass in the values-buffer pointer.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
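
// Typical use of `NewCallParams`, as seen in the conversion patterns below:
// all static parameters are materialized once through `genBuffers`, after
// which `genNewCall` emits the actual `newSparseTensor` runtime call, e.g.
//
//   Value tensor = NewCallParams(rewriter, loc)
//                      .genBuffers(stt, dimSizesValues)
//                      .genNewCall(Action::kEmpty);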

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}
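
// For instance, for an `f64` element type the call above resolves to the
// runtime entry point `_mlir_ciface_sparseValuesF64` (the prefix stems from
// `EmitCInterface::On`), which returns a rank-1 memref view of the values.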

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}
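
// Analogously, the two helpers above resolve to entry points such as
// `sparsePositions64` and `sparseCoordinates64` (modulo the `_mlir_ciface_`
// prefix), where the suffix reflects the position/coordinate overhead type.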

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite LvlOp on sparse tensors.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};
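
// An example of a cast handled by the pattern above (illustrative; `#CSR`
// stands for some sparse encoding alias):
//
//   %1 = tensor.cast %0 : tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR>
//
// Since source and destination carry the same encoding, the cast is a no-op
// on the underlying opaque pointer and simply forwards the converted operand.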

/// Sparse conversion rule for the reinterpret map operator, which simply
/// folds the operation away, since the underlying storage is unchanged.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
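
// Illustrative lowering performed by the two patterns above: an annotated
//
//   %t = tensor.empty() : tensor<100x100xf64, #CSR>
//
// becomes a single `newSparseTensor` runtime call with `Action::kEmpty`,
// preceded by the stack-allocated buffers for level types, dim/lvl sizes,
// and the dim<->lvl mappings generated by `NewCallParams::genBuffers`.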

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
                                   adaptor.getTensor(), op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the size of the values array, which equals the number of
    // actually stored entries.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion is not handled by this pattern.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Find the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoist the allocas outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};
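
// The emitted insertion is a single runtime call of roughly the following
// shape (for `f64` values; `%coords` and `%vref` denote the stack buffers
// allocated above):
//
//   call @_mlir_ciface_lexInsertF64(%tensor, %coords, %vref) : (...) -> ()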

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We simply expose the buffers to the external client. This
    // assumes the client only reads the buffers (usually copying them
    // into external data structures, such as numpy arrays).
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing
        // COO region. Since the target coordinate buffer used for the
        // trailing COO is passed in using an AoS scheme while
        // SparseTensorStorage uses a SoA scheme, we cannot simply reuse
        // the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //    for (i = 0; i < crdLen; i++)
      //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Convert the MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Append the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

/// Sparse conversion rule for the has_runtime_library operator.
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
           SparseTensorLoadConverter, SparseTensorInsertConverter,
           SparseTensorExpandConverter, SparseTensorCompressConverter,
           SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
           SparseHasRuntimeLibraryConverter>(typeConverter,
                                             patterns.getContext());
}
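
// A minimal sketch of how a conversion pass typically wires up the type
// converter and the patterns above (the actual pass lives elsewhere in this
// directory; the dialect legality shown here is only an assumption):
//
//   SparseTensorTypeToPtrConverter converter;
//   ConversionTarget target(*context);
//   target.addIllegalDialect<sparse_tensor::SparseTensorDialect>();
//   target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect,
//                          memref::MemRefDialect, scf::SCFDialect>();
//   RewritePatternSet patterns(context);
//   populateSparseTensorConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();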