//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with a runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
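/// For example, an annotated `tensor<?x?xf64, #CSR>` is converted to the
/// opaque `!llvm.ptr` handle of its underlying runtime storage.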
static std::optional<Type> convertSparseTensorTypes(Type type) {
  if (getSparseTensorEncoding(type) != nullptr)
    return LLVM::LLVMPointerType::get(type.getContext());
  return std::nullopt;
}

/// Generates a call to look up a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {
  StringRef name = "sparseLvlSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Generates a call to look up a dimension-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {
  StringRef name = "sparseDimSize";
  SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
  Type iTp = builder.getIndexType();
  return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
      .getResult(0);
}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {
  // Only sparse tensors have "levels" to query.
  assert(stt.hasEncoding());
  // TODO: The following implementation only handles permutations;
  // we'll need to generalize this to handle arbitrary AffineExpr.
  //
  // There's no need to assert `isPermutation` here: because
  // `getDimPosition` checks that the expr isa `AffineDimExpr`,
  // which is all we care about (for supporting permutations).
  const Dimension dim =
      stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  // If we cannot statically compute the size from the shape, then we
  // must dynamically query it.  (In principle we could also dynamically
  // compute it, but since we already did so to construct the `tensor`
  // in the first place, we might as well query rather than recompute.)
  return genLvlSizeCall(builder, loc, tensor, lvl);
}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {
  const Size sz = stt.getDynamicDimSize(dim);
  if (!ShapedType::isDynamic(sz))
    return constantIndex(builder, loc, sz);
  if (stt.hasEncoding())
    return genDimSizeCall(builder, loc, tensor, dim);
  return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
}

/// Populates the array with the dimension-sizes of the given tensor.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {
  const Dimension dimRank = stt.getDimRank();
  out.clear();
  out.reserve(dimRank);
  for (Dimension d = 0; d < dimRank; d++)
    out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the *tensor* parameter is null, the tensor type is assumed to have a
/// static shape.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {
  SmallVector<Value> out;
  fillDimSizes(builder, loc, stt, tensor, out);
  return out;
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
/// this buffer must be explicitly deallocated by the client.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}

/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {
  SmallVector<Value> lvlTypes;
  lvlTypes.reserve(stt.getLvlRank());
  for (const auto lt : stt.getEncoding().getLvlTypes())
    lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
  return allocaBuffer(builder, loc, lvlTypes);
}

/// Extracts the bare (aligned) pointer to the given tensor's underlying buffer.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {
  auto buf = genToMemref(builder, loc, tensor);
  return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
}

/// Packs the bare pointers of the given level buffers and value buffer
/// into a single opaque pointer for the runtime library.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
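/// A typical use chains the calls, as done in the conversion patterns below:
///   NewCallParams(rewriter, loc)
///       .genBuffers(stt, dimSizesValues)
///       .genNewCall(Action::kEmpty);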
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};

/// Generates a call to obtain the values array.
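/// The callee name is suffixed by the element type (e.g., `sparseValuesF64`
/// when the values are `f64`).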
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {
  auto eltTp = stt.getElementType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the positions array.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

/// Generates a call to obtain the coordinates array (AoS view).
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
                                      SparseTensorType stt, Value ptr,
                                      Level l) {
  Type crdTp = stt.getCrdType();
  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  Value lvl = constantIndex(builder, loc, l);
  SmallString<25> name{"sparseCoordinatesBuffer",
                       overheadTypeFunctionSuffix(crdTp)};
  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
                        EmitCInterface::On)
      .getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for accessing level-sizes.
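/// For example (illustrative), a level query on a sparse tensor lowers to a
/// `sparseLvlSize` runtime call on the opaque tensor handle.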
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};

/// Sparse conversion rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

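/// Sparse conversion rule for the reinterpret_map operator (folded away).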
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};

/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct an empty tensor. The sizes are
    // explicitly defined by the arguments to the empty tensor operation.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};

/// Sparse conversion rule for the reorder_coo operator.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getInputCoo());
    const auto dstTp = getSparseTensorType(op);

    const Value src = adaptor.getInputCoo();

    NewCallParams params(rewriter, loc);
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
    rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
                               .genNewCall(Action::kSortCOOInPlace, src));

    return success();
  }
};

/// Sparse conversion rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorType(op.getTensor()).hasEncoding())
      return failure();
    StringRef name = "delSparseTensor";
    createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Sparse conversion rule for position accesses.
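/// The positions of the requested level are obtained through a
/// `sparsePositions` runtime call (suffixed by the position overhead type).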
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
                                 adaptor.getTensor(), op.getLevel());
    rewriter.replaceOp(op, poss);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                   op.getLevel());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for coordinate accesses (AoS style).
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    auto crds = genCoordinatesBufferCall(
        rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
    // Cast the MemRef type to the type expected by the users, though these
    // two types should be compatible at runtime.
    if (op.getType() != crds.getType())
      crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
    rewriter.replaceOp(op, crds);
    return success();
  }
};

/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    rewriter.replaceOp(op, vals);
    return success();
  }
};

/// Sparse conversion rule for number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query the values array size to obtain the number of stored entries.
    auto stt = getSparseTensorType(op.getTensor());
    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
    auto zero = constantIndex(rewriter, op.getLoc(), 0);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
    return success();
  }
};

/// Sparse conversion rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getHasInserts()) {
      // Finalize any pending insertions.
      StringRef name = "endLexInsert";
      createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                     EmitCInterface::Off);
    }
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
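    // For example, inserting an f64 element lowers to a `lexInsertF64` call
    // with the level coordinates and the value passed through these buffers.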
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};

/// Sparse conversion rule for the expand operator.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};

/// Sparse conversion rule for the compress operator.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
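    // The emitted callee is `expInsert` suffixed by the element type
    // (e.g., `expInsertF64`); see the name construction below.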
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};

/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying them to external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of the trailing COO
        // region. Since the target coordinate buffer used for trailing
        // COO is passed in as an AoS scheme while SparseTensorStorage uses
        // a SoA scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //    for (i = 0; i < crdLen; i++)
      //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

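/// Sparse conversion rule for the has_runtime_library operator; in this
/// runtime-library lowering path the answer is always a constant true.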
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
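/// A lowering pass typically pairs these patterns with the
/// `SparseTensorTypeToPtrConverter` above when driving a dialect conversion
/// (sketch; the actual driver lives in the pass definition).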
void mlir::populateSparseTensorConversionPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns
      .add<SparseReturnConverter, SparseTensorLvlOpConverter,
           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
           SparseTensorAllocConverter, SparseTensorEmptyConverter,
           SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
           SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
           SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
           SparseTensorInsertConverter, SparseTensorExpandConverter,
           SparseTensorCompressConverter, SparseTensorAssembleConverter,
           SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
          typeConverter, patterns.getContext());
}