xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
121895486Swren romano //===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
2a2c9d4bbSAart Bik //
3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information.
5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a2c9d4bbSAart Bik //
7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
8a2c9d4bbSAart Bik //
986b22d31SAart Bik // A pass that converts sparse tensor primitives into calls into a runtime
1086b22d31SAart Bik // support library. Sparse tensor types are converted into opaque pointers
1186b22d31SAart Bik // to the underlying sparse storage schemes. The use of opaque pointers
1286b22d31SAart Bik // together with runtime support library keeps the conversion relatively
1386b22d31SAart Bik // simple, but at the expense of IR opacity, which obscures opportunities
1486b22d31SAart Bik // for subsequent optimization of the IR. An alternative is provided by
1586b22d31SAart Bik // the SparseTensorCodegen pass.
16a2c9d4bbSAart Bik //
17a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
18a2c9d4bbSAart Bik 
19365777ecSAart Bik #include "Utils/CodegenUtils.h"
20efa15f41SAart Bik 
21c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
2257470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23236a9080SAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h"
24a2c9d4bbSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h"
258b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
26062f29c8Swren romano #include "mlir/Dialect/SparseTensor/IR/Enums.h"
27a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30ca5d0a73SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
31a2c9d4bbSAart Bik #include "mlir/Transforms/DialectConversion.h"
32a2c9d4bbSAart Bik 
33a2c9d4bbSAart Bik using namespace mlir;
3496a23911SAart Bik using namespace mlir::sparse_tensor;
35a2c9d4bbSAart Bik 
36a2c9d4bbSAart Bik namespace {
37a2c9d4bbSAart Bik 
3805c7f450SAart Bik //===----------------------------------------------------------------------===//
3905c7f450SAart Bik // Helper methods.
4005c7f450SAart Bik //===----------------------------------------------------------------------===//
4105c7f450SAart Bik 
4286b22d31SAart Bik /// Maps each sparse tensor type to an opaque pointer.
430de16fafSRamkumar Ramachandra static std::optional<Type> convertSparseTensorTypes(Type type) {
4486b22d31SAart Bik   if (getSparseTensorEncoding(type) != nullptr)
45dcae289dSChristian Ulmann     return LLVM::LLVMPointerType::get(type.getContext());
461a36588eSKazu Hirata   return std::nullopt;
4786b22d31SAart Bik }
4886b22d31SAart Bik 
4986f91e45Swren romano /// Generates call to lookup a level-size.  N.B., this only generates
5086f91e45Swren romano /// the raw function call, and therefore (intentionally) does not perform
5186f91e45Swren romano /// any dim<->lvl conversion or other logic.
5286f91e45Swren romano static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
53c518745bSwren romano                             uint64_t lvl) {
54c518745bSwren romano   StringRef name = "sparseLvlSize";
5586f91e45Swren romano   SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, lvl)};
56e9fa5590SMatthias Springer   Type iTp = builder.getIndexType();
57ee986ab7SPeiming Liu   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
58d8731bfcSwren romano       .getResult(0);
599d1db3d4SAart Bik }
609d1db3d4SAart Bik 
6186f91e45Swren romano /// Generates call to lookup a dimension-size.  N.B., this only generates
6286f91e45Swren romano /// the raw function call, and therefore (intentionally) does not perform
6386f91e45Swren romano /// any dim<->lvl conversion or other logic.
6486f91e45Swren romano static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
6586f91e45Swren romano                             uint64_t dim) {
6686f91e45Swren romano   StringRef name = "sparseDimSize";
6786f91e45Swren romano   SmallVector<Value, 2> params{tensor, constantIndex(builder, loc, dim)};
6886f91e45Swren romano   Type iTp = builder.getIndexType();
6986f91e45Swren romano   return createFuncCall(builder, loc, name, iTp, params, EmitCInterface::Off)
7086f91e45Swren romano       .getResult(0);
7186f91e45Swren romano }
7286f91e45Swren romano 
7386f91e45Swren romano /// Looks up a level-size by returning a statically-computed constant
7486f91e45Swren romano /// (when possible), or by calling `genLvlSizeCall` (when dynamic).
7586f91e45Swren romano static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
76f708a549Swren romano                                  SparseTensorType stt, Value tensor,
77f708a549Swren romano                                  Level lvl) {
7886f91e45Swren romano   // Only sparse tensors have "levels" to query.
79f708a549Swren romano   assert(stt.hasEncoding());
8086f91e45Swren romano   // TODO: The following implementation only handles permutations;
8186f91e45Swren romano   // we'll need to generalize this to handle arbitrary AffineExpr.
8286f91e45Swren romano   //
8386f91e45Swren romano   // There's no need to assert `isPermutation` here: because
8486f91e45Swren romano   // `getDimPosition` checks that the expr isa `AffineDimExpr`,
8586f91e45Swren romano   // which is all we care about (for supporting permutations).
86f708a549Swren romano   const Dimension dim =
8776647fceSwren romano       stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
8822212ca7SAart Bik   const Size sz = stt.getDynamicDimSize(dim);
8922212ca7SAart Bik   if (!ShapedType::isDynamic(sz))
9022212ca7SAart Bik     return constantIndex(builder, loc, sz);
9186f91e45Swren romano   // If we cannot statically compute the size from the shape, then we
9286f91e45Swren romano   // must dynamically query it.  (In principle we could also dynamically
9386f91e45Swren romano   // compute it, but since we already did so to construct the `tensor`
9486f91e45Swren romano   // in the first place, we might as well query rather than recompute.)
9586f91e45Swren romano   return genLvlSizeCall(builder, loc, tensor, lvl);
96c248219bSPeiming Liu }
97c248219bSPeiming Liu 
9886f91e45Swren romano /// Looks up a dimension-size by returning a constant from the shape
9986f91e45Swren romano /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
10086f91e45Swren romano /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
10186f91e45Swren romano /// of dense tensors).
10286f91e45Swren romano static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
103f708a549Swren romano                                  SparseTensorType stt, Value tensor,
104f708a549Swren romano                                  Dimension dim) {
10522212ca7SAart Bik   const Size sz = stt.getDynamicDimSize(dim);
10622212ca7SAart Bik   if (!ShapedType::isDynamic(sz))
10722212ca7SAart Bik     return constantIndex(builder, loc, sz);
108f708a549Swren romano   if (stt.hasEncoding())
10986f91e45Swren romano     return genDimSizeCall(builder, loc, tensor, dim);
11086f91e45Swren romano   return linalg::createOrFoldDimOp(builder, loc, tensor, dim);
111c248219bSPeiming Liu }
112c248219bSPeiming Liu 
11386f91e45Swren romano /// Populates the array with the dimension-sizes of the given tensor.
114f708a549Swren romano static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
11586f91e45Swren romano                          Value tensor, SmallVectorImpl<Value> &out) {
116f708a549Swren romano   const Dimension dimRank = stt.getDimRank();
117f708a549Swren romano   out.clear();
11886f91e45Swren romano   out.reserve(dimRank);
119f708a549Swren romano   for (Dimension d = 0; d < dimRank; d++)
120f708a549Swren romano     out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
1219d1db3d4SAart Bik }
12286f91e45Swren romano 
12386f91e45Swren romano /// Returns an array with the dimension-sizes of the given tensor.
124fa6726e2SPeiming Liu /// If the *tensor* parameters is null, the tensor type is assumed to have a
125fa6726e2SPeiming Liu /// static shape.
12686f91e45Swren romano static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
127fa6726e2SPeiming Liu                                       SparseTensorType stt,
128fa6726e2SPeiming Liu                                       Value tensor = Value()) {
12986f91e45Swren romano   SmallVector<Value> out;
130f708a549Swren romano   fillDimSizes(builder, loc, stt, tensor, out);
13186f91e45Swren romano   return out;
13286f91e45Swren romano }
13386f91e45Swren romano 
1340b55f94dSAart Bik /// Generates an uninitialized buffer of the given size and type,
1350b55f94dSAart Bik /// but returns it as type `memref<? x $tp>` (rather than as type
1360b55f94dSAart Bik /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
1370b55f94dSAart Bik /// this buffer must be explicitly deallocated by client.
138e9fa5590SMatthias Springer static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
139399638f9SAliia Khasanova   auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
1400b55f94dSAart Bik   return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
1410b55f94dSAart Bik }
1420b55f94dSAart Bik 
1432af2e4dbSwren romano /// Generates a temporary buffer for the level-types of the given encoding.
1442af2e4dbSwren romano static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
145f708a549Swren romano                                SparseTensorType stt) {
1462af2e4dbSwren romano   SmallVector<Value> lvlTypes;
147f708a549Swren romano   lvlTypes.reserve(stt.getLvlRank());
1481dd387e1SAart Bik   for (const auto lt : stt.getEncoding().getLvlTypes())
1491944c4f7SAart Bik     lvlTypes.push_back(constantLevelTypeEncoding(builder, loc, lt));
15062896428Swren romano   return allocaBuffer(builder, loc, lvlTypes);
1512af2e4dbSwren romano }
1522af2e4dbSwren romano 
153fa6726e2SPeiming Liu /// Extracts the bare (aligned) pointers that point to the tensor.
154fa6726e2SPeiming Liu static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
155fa6726e2SPeiming Liu                                       Value tensor) {
156fa6726e2SPeiming Liu   auto buf = genToMemref(builder, loc, tensor);
157fa6726e2SPeiming Liu   return builder.create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
158fa6726e2SPeiming Liu }
159fa6726e2SPeiming Liu 
/// Generates an opaque pointer to a stack buffer that packs the bare
/// (aligned) pointers of all level buffers followed by the bare pointer
/// of the value buffer.  (NOTE: the original header comment said
/// "level-types", which was a copy-paste error from genLvlTypesBuffer.)
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {
  SmallVector<Value> lvlBarePtrs;
  lvlBarePtrs.reserve(lvlTensors.size() + 1);
  // Passing in lvl buffer pointers.
  for (const auto lvl : lvlTensors)
    lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));

  // Passing in value buffer pointers.
  lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
  // Pack all bare pointers into one alloca'd buffer, take that buffer's
  // own aligned address (an index), and reinterpret it as an opaque
  // LLVM pointer via i64.
  Value idxPtr = builder.create<memref::ExtractAlignedPointerAsIndexOp>(
      loc, allocaBuffer(builder, loc, lvlBarePtrs));
  Value idxCast =
      builder.create<arith::IndexCastOp>(loc, builder.getI64Type(), idxPtr);
  return builder.create<LLVM::IntToPtrOp>(loc, getOpaquePointerType(builder),
                                          idxCast);
}
178fa6726e2SPeiming Liu 
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
  /// Allocates the `ValueRange` for the `func::CallOp` parameters.
  NewCallParams(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

  /// Initializes all static parameters (i.e., those which indicate
  /// type-level information such as the encoding and sizes), generating
  /// MLIR buffers as needed, and returning `this` for method chaining.
  /// When `dimSizesBuffer` is non-null it is used directly as the
  /// dim-sizes parameter; otherwise a buffer is alloca'd from
  /// `dimSizesValues`.
  NewCallParams &genBuffers(SparseTensorType stt,
                            ArrayRef<Value> dimSizesValues,
                            Value dimSizesBuffer = Value()) {
    assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
    // Sparsity annotations.
    params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
    // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
    params[kParamDimSizes] = dimSizesBuffer
                                 ? dimSizesBuffer
                                 : allocaBuffer(builder, loc, dimSizesValues);
    // genMapBuffers fills the dim2lvl/lvl2dim params as out-arguments and
    // returns the lvl-sizes buffer.
    SmallVector<Value> lvlSizesValues; // unused
    params[kParamLvlSizes] = genMapBuffers(
        builder, loc, stt, dimSizesValues, params[kParamDimSizes],
        lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
    // Secondary and primary types encoding.
    const auto enc = stt.getEncoding();
    params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
    params[kParamCrdTp] = constantCrdTypeEncoding(builder, loc, enc);
    params[kParamValTp] =
        constantPrimaryTypeEncoding(builder, loc, stt.getElementType());
    // Return `this` for method chaining.
    return *this;
  }

  /// Checks whether all the static parameters have been initialized.
  bool isInitialized() const {
    for (unsigned i = 0; i < kNumStaticParams; ++i)
      if (!params[i])
        return false;
    return true;
  }

  /// Generates a function call, with the current static parameters
  /// and the given dynamic arguments.  A null `ptr` is materialized
  /// as an LLVM zero (null) pointer.
  Value genNewCall(Action action, Value ptr = Value()) {
    assert(isInitialized() && "Must initialize before genNewCall");
    StringRef name = "newSparseTensor";
    params[kParamAction] = constantAction(builder, loc, action);
    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::ZeroOp>(loc, pTp);
    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
        .getResult(0);
  }

private:
  // NOTE: the `kParam*` indices define the positional order of the
  // `newSparseTensor` runtime call's arguments and must stay in sync
  // with the support library's ABI.
  static constexpr unsigned kNumStaticParams = 8;
  static constexpr unsigned kNumDynamicParams = 2;
  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
  static constexpr unsigned kParamDimSizes = 0;
  static constexpr unsigned kParamLvlSizes = 1;
  static constexpr unsigned kParamLvlTypes = 2;
  static constexpr unsigned kParamDim2Lvl = 3;
  static constexpr unsigned kParamLvl2Dim = 4;
  static constexpr unsigned kParamPosTp = 5;
  static constexpr unsigned kParamCrdTp = 6;
  static constexpr unsigned kParamValTp = 7;
  static constexpr unsigned kParamAction = 8;
  static constexpr unsigned kParamPtr = 9;

  OpBuilder &builder;
  Location loc;
  Type pTp;
  Value params[kNumParams];
};
255f5ce99afSwren romano 
2560f3e4d1aSAart Bik /// Generates a call to obtain the values array.
257af8428c0SAart Bik static Value genValuesCall(OpBuilder &builder, Location loc,
258af8428c0SAart Bik                            SparseTensorType stt, Value ptr) {
259af8428c0SAart Bik   auto eltTp = stt.getElementType();
260af8428c0SAart Bik   auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
261af8428c0SAart Bik   SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
262af8428c0SAart Bik   return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
263af8428c0SAart Bik       .getResult(0);
264af8428c0SAart Bik }
265af8428c0SAart Bik 
266af8428c0SAart Bik /// Generates a call to obtain the positions array.
267af8428c0SAart Bik static Value genPositionsCall(OpBuilder &builder, Location loc,
268af8428c0SAart Bik                               SparseTensorType stt, Value ptr, Level l) {
269af8428c0SAart Bik   Type posTp = stt.getPosType();
270af8428c0SAart Bik   auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
271af8428c0SAart Bik   Value lvl = constantIndex(builder, loc, l);
272af8428c0SAart Bik   SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
273af8428c0SAart Bik   return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
274af8428c0SAart Bik                         EmitCInterface::On)
275af8428c0SAart Bik       .getResult(0);
276af8428c0SAart Bik }
277af8428c0SAart Bik 
278dc4cfdbbSAart Bik /// Generates a call to obtain the coordinates array.
279af8428c0SAart Bik static Value genCoordinatesCall(OpBuilder &builder, Location loc,
280af8428c0SAart Bik                                 SparseTensorType stt, Value ptr, Level l) {
281af8428c0SAart Bik   Type crdTp = stt.getCrdType();
282af8428c0SAart Bik   auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
283af8428c0SAart Bik   Value lvl = constantIndex(builder, loc, l);
284af8428c0SAart Bik   SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
285af8428c0SAart Bik   return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
286af8428c0SAart Bik                         EmitCInterface::On)
2870f3e4d1aSAart Bik       .getResult(0);
2880f3e4d1aSAart Bik }
2890f3e4d1aSAart Bik 
290dc4cfdbbSAart Bik /// Generates a call to obtain the coordinates array (AoS view).
291dc4cfdbbSAart Bik static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
292dc4cfdbbSAart Bik                                       SparseTensorType stt, Value ptr,
293dc4cfdbbSAart Bik                                       Level l) {
294dc4cfdbbSAart Bik   Type crdTp = stt.getCrdType();
295dc4cfdbbSAart Bik   auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
296dc4cfdbbSAart Bik   Value lvl = constantIndex(builder, loc, l);
297dc4cfdbbSAart Bik   SmallString<25> name{"sparseCoordinatesBuffer",
298dc4cfdbbSAart Bik                        overheadTypeFunctionSuffix(crdTp)};
299dc4cfdbbSAart Bik   return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
300dc4cfdbbSAart Bik                         EmitCInterface::On)
301dc4cfdbbSAart Bik       .getResult(0);
302dc4cfdbbSAart Bik }
303dc4cfdbbSAart Bik 
30405c7f450SAart Bik //===----------------------------------------------------------------------===//
30505c7f450SAart Bik // Conversion rules.
30605c7f450SAart Bik //===----------------------------------------------------------------------===//
30705c7f450SAart Bik 
30896a23911SAart Bik /// Sparse conversion rule for returns.
30923aa5a74SRiver Riddle class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
31096a23911SAart Bik public:
311a2c9d4bbSAart Bik   using OpConversionPattern::OpConversionPattern;
312a2c9d4bbSAart Bik   LogicalResult
31323aa5a74SRiver Riddle   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
314a2c9d4bbSAart Bik                   ConversionPatternRewriter &rewriter) const override {
31523aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
316a2c9d4bbSAart Bik     return success();
317a2c9d4bbSAart Bik   }
318a2c9d4bbSAart Bik };
319a2c9d4bbSAart Bik 
/// Sparse conversion rule for accessing level-sizes.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op.getSource());
    // Only rewrite sparse LvlOp.
    if (!stt.hasEncoding())
      return failure();

    // Only rewrite LvlOp with a constant level index.
    std::optional<int64_t> lvl = op.getConstantLvlIndex();

    if (!lvl)
      return failure();

    // By now, if the level size is constant, the operation should have already
    // been folded by LvlOp's folder, so we generate the call unconditionally.
    Value src = adaptor.getOperands()[0];
    rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
    return success();
  }
};
345a2c9d4bbSAart Bik 
3461b15160eSAart Bik /// Sparse conversion rule for trivial tensor casts.
3471b15160eSAart Bik class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
348faa00c13SAart Bik public:
3491b15160eSAart Bik   using OpConversionPattern::OpConversionPattern;
3501b15160eSAart Bik   LogicalResult
3511b15160eSAart Bik   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
3521b15160eSAart Bik                   ConversionPatternRewriter &rewriter) const override {
3531b15160eSAart Bik     // Only rewrite identically annotated source/dest.
3541b15160eSAart Bik     auto encDst = getSparseTensorEncoding(op.getType());
3558df54a6aSJacques Pienaar     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
3561b15160eSAart Bik     if (!encDst || encDst != encSrc)
3571b15160eSAart Bik       return failure();
3581b15160eSAart Bik     rewriter.replaceOp(op, adaptor.getOperands());
3591b15160eSAart Bik     return success();
3601b15160eSAart Bik   }
3611b15160eSAart Bik };
3621b15160eSAart Bik 
363ef222988SPeiming Liu class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
364ef222988SPeiming Liu public:
365ef222988SPeiming Liu   using OpConversionPattern::OpConversionPattern;
366ef222988SPeiming Liu   LogicalResult
367ef222988SPeiming Liu   matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
368ef222988SPeiming Liu                   ConversionPatternRewriter &rewriter) const override {
369ef222988SPeiming Liu     // Simply fold the operation.
370ef222988SPeiming Liu     rewriter.replaceOp(op, adaptor.getSource());
371ef222988SPeiming Liu     return success();
372ef222988SPeiming Liu   }
373ef222988SPeiming Liu };
374ef222988SPeiming Liu 
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Construct the `reader` opening method calls.  `genReader` fills
    // `dimSizesValues`/`dimSizesBuffer` as out-arguments from the file's
    // header.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
    // Use the `reader` to parse the file.
    Value tensor = NewCallParams(rewriter, loc)
                       .genBuffers(stt, dimSizesValues, dimSizesBuffer)
                       .genNewCall(Action::kFromReader, reader);
    // Free the memory for `reader`.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
    rewriter.replaceOp(op, tensor);
    return success();
  }
};
402b24788abSAart Bik 
/// Sparse conversion rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    if (op.getCopy())
      return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
    // Gather all dimension sizes as SSA values.  Dynamic sizes are
    // consumed from the adaptor's operands in order (`operandCtr`);
    // static sizes become constants.
    Location loc = op.getLoc();
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(
          stt.isDynamicDim(d)
              ? adaptor.getOperands()[operandCtr++]
              : constantIndex(rewriter, loc, op.getStaticSize(d)));
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the alloc operator.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
43796a23911SAart Bik 
/// Sparse conversion rule for the empty tensor.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op);
    if (!stt.hasEncoding())
      return failure();
    // Gather all dimension sizes as SSA values.  Dynamic sizes are
    // consumed from the adaptor's operands in order (`operandCtr`);
    // static sizes come from the result type's shape.
    const Dimension dimRank = stt.getDimRank();
    SmallVector<Value> dimSizesValues;
    dimSizesValues.reserve(dimRank);
    auto shape = op.getType().getShape();
    unsigned operandCtr = 0;
    for (Dimension d = 0; d < dimRank; d++) {
      dimSizesValues.push_back(stt.isDynamicDim(d)
                                   ? adaptor.getOperands()[operandCtr++]
                                   : constantIndex(rewriter, loc, shape[d]));
    }
    // Generate the call to construct empty tensor. The sizes are
    // explicitly defined by the arguments to the tensor.empty operation.
    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
                               .genBuffers(stt, dimSizesValues)
                               .genNewCall(Action::kEmpty));
    return success();
  }
};
468c6472f57SAart Bik 
469697ea09dSAart Bik /// Sparse conversion rule for the convert operator.
470f248d0b2SPeiming Liu class SparseTensorReorderCOOConverter
471f248d0b2SPeiming Liu     : public OpConversionPattern<ReorderCOOOp> {
472c7e24db4Swren romano public:
473697ea09dSAart Bik   using OpConversionPattern::OpConversionPattern;
474c7e24db4Swren romano 
475697ea09dSAart Bik   LogicalResult
476f248d0b2SPeiming Liu   matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
477697ea09dSAart Bik                   ConversionPatternRewriter &rewriter) const override {
478f708a549Swren romano     const Location loc = op->getLoc();
479f248d0b2SPeiming Liu     const auto srcTp = getSparseTensorType(op.getInputCoo());
480f708a549Swren romano     const auto dstTp = getSparseTensorType(op);
481f708a549Swren romano 
482f248d0b2SPeiming Liu     const Value src = adaptor.getInputCoo();
483f248d0b2SPeiming Liu 
484f5ce99afSwren romano     NewCallParams params(rewriter, loc);
4852323f48eSAart Bik     SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
4862323f48eSAart Bik     rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
487f248d0b2SPeiming Liu                                .genNewCall(Action::kSortCOOInPlace, src));
488faa00c13SAart Bik 
489f248d0b2SPeiming Liu     return success();
490f248d0b2SPeiming Liu   }
491697ea09dSAart Bik };
492697ea09dSAart Bik 
49327a431f5SMatthias Springer /// Sparse conversion rule for the dealloc operator.
49427a431f5SMatthias Springer class SparseTensorDeallocConverter
49527a431f5SMatthias Springer     : public OpConversionPattern<bufferization::DeallocTensorOp> {
49616b8f4ddSAart Bik public:
49716b8f4ddSAart Bik   using OpConversionPattern::OpConversionPattern;
49816b8f4ddSAart Bik   LogicalResult
49927a431f5SMatthias Springer   matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
50016b8f4ddSAart Bik                   ConversionPatternRewriter &rewriter) const override {
501f708a549Swren romano     if (!getSparseTensorType(op.getTensor()).hasEncoding())
50227a431f5SMatthias Springer       return failure();
50316b8f4ddSAart Bik     StringRef name = "delSparseTensor";
504ee986ab7SPeiming Liu     createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
505d8731bfcSwren romano                    EmitCInterface::Off);
50616b8f4ddSAart Bik     rewriter.eraseOp(op);
50716b8f4ddSAart Bik     return success();
50816b8f4ddSAart Bik   }
50916b8f4ddSAart Bik };
51016b8f4ddSAart Bik 
51184cd51bbSwren romano /// Sparse conversion rule for position accesses.
51284cd51bbSwren romano class SparseTensorToPositionsConverter
51384cd51bbSwren romano     : public OpConversionPattern<ToPositionsOp> {
514a2c9d4bbSAart Bik public:
515a2c9d4bbSAart Bik   using OpConversionPattern::OpConversionPattern;
516a2c9d4bbSAart Bik   LogicalResult
51784cd51bbSwren romano   matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
518a2c9d4bbSAart Bik                   ConversionPatternRewriter &rewriter) const override {
519af8428c0SAart Bik     auto stt = getSparseTensorType(op.getTensor());
520af8428c0SAart Bik     auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
521af8428c0SAart Bik                                  adaptor.getTensor(), op.getLevel());
522af8428c0SAart Bik     rewriter.replaceOp(op, poss);
523a2c9d4bbSAart Bik     return success();
524a2c9d4bbSAart Bik   }
525a2c9d4bbSAart Bik };
526a2c9d4bbSAart Bik 
52784cd51bbSwren romano /// Sparse conversion rule for coordinate accesses.
52884cd51bbSwren romano class SparseTensorToCoordinatesConverter
52984cd51bbSwren romano     : public OpConversionPattern<ToCoordinatesOp> {
530a2c9d4bbSAart Bik public:
531a2c9d4bbSAart Bik   using OpConversionPattern::OpConversionPattern;
532a2c9d4bbSAart Bik   LogicalResult
53384cd51bbSwren romano   matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
534a2c9d4bbSAart Bik                   ConversionPatternRewriter &rewriter) const override {
535dc4cfdbbSAart Bik     const Location loc = op.getLoc();
536af8428c0SAart Bik     auto stt = getSparseTensorType(op.getTensor());
537dc4cfdbbSAart Bik     auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
538dc4cfdbbSAart Bik                                    op.getLevel());
53990aa4362Sbixia1     // Cast the MemRef type to the type expected by the users, though these
54090aa4362Sbixia1     // two types should be compatible at runtime.
541af8428c0SAart Bik     if (op.getType() != crds.getType())
542dc4cfdbbSAart Bik       crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
543dc4cfdbbSAart Bik     rewriter.replaceOp(op, crds);
544dc4cfdbbSAart Bik     return success();
545dc4cfdbbSAart Bik   }
546dc4cfdbbSAart Bik };
547dc4cfdbbSAart Bik 
548dc4cfdbbSAart Bik /// Sparse conversion rule for coordinate accesses (AoS style).
549dc4cfdbbSAart Bik class SparseToCoordinatesBufferConverter
550dc4cfdbbSAart Bik     : public OpConversionPattern<ToCoordinatesBufferOp> {
551dc4cfdbbSAart Bik public:
552dc4cfdbbSAart Bik   using OpConversionPattern::OpConversionPattern;
553dc4cfdbbSAart Bik   LogicalResult
554dc4cfdbbSAart Bik   matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
555dc4cfdbbSAart Bik                   ConversionPatternRewriter &rewriter) const override {
556dc4cfdbbSAart Bik     const Location loc = op.getLoc();
557dc4cfdbbSAart Bik     auto stt = getSparseTensorType(op.getTensor());
558dc4cfdbbSAart Bik     auto crds = genCoordinatesBufferCall(
559dc4cfdbbSAart Bik         rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
560dc4cfdbbSAart Bik     // Cast the MemRef type to the type expected by the users, though these
561dc4cfdbbSAart Bik     // two types should be compatible at runtime.
562dc4cfdbbSAart Bik     if (op.getType() != crds.getType())
563dc4cfdbbSAart Bik       crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
564af8428c0SAart Bik     rewriter.replaceOp(op, crds);
565a2c9d4bbSAart Bik     return success();
566a2c9d4bbSAart Bik   }
567a2c9d4bbSAart Bik };
568a2c9d4bbSAart Bik 
569a2c9d4bbSAart Bik /// Sparse conversion rule for value accesses.
57096a23911SAart Bik class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
571a2c9d4bbSAart Bik public:
572a2c9d4bbSAart Bik   using OpConversionPattern::OpConversionPattern;
573a2c9d4bbSAart Bik   LogicalResult
574b54c724bSRiver Riddle   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
575a2c9d4bbSAart Bik                   ConversionPatternRewriter &rewriter) const override {
576af8428c0SAart Bik     auto stt = getSparseTensorType(op.getTensor());
577af8428c0SAart Bik     auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
578af8428c0SAart Bik     rewriter.replaceOp(op, vals);
5790f3e4d1aSAart Bik     return success();
5800f3e4d1aSAart Bik   }
5810f3e4d1aSAart Bik };
5820f3e4d1aSAart Bik 
5830f3e4d1aSAart Bik /// Sparse conversion rule for number of entries operator.
5840f3e4d1aSAart Bik class SparseNumberOfEntriesConverter
5850f3e4d1aSAart Bik     : public OpConversionPattern<NumberOfEntriesOp> {
5860f3e4d1aSAart Bik public:
5870f3e4d1aSAart Bik   using OpConversionPattern::OpConversionPattern;
5880f3e4d1aSAart Bik   LogicalResult
5890f3e4d1aSAart Bik   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
5900f3e4d1aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
5910f3e4d1aSAart Bik     // Query values array size for the actually stored values size.
592af8428c0SAart Bik     auto stt = getSparseTensorType(op.getTensor());
593af8428c0SAart Bik     auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
594af8428c0SAart Bik     auto zero = constantIndex(rewriter, op.getLoc(), 0);
595af8428c0SAart Bik     rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
596a2c9d4bbSAart Bik     return success();
597a2c9d4bbSAart Bik   }
598a2c9d4bbSAart Bik };
599a2c9d4bbSAart Bik 
600f66e5769SAart Bik /// Sparse conversion rule for tensor rematerialization.
601f66e5769SAart Bik class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
602727a63e0SAart Bik public:
603727a63e0SAart Bik   using OpConversionPattern::OpConversionPattern;
604727a63e0SAart Bik   LogicalResult
605f66e5769SAart Bik   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
606727a63e0SAart Bik                   ConversionPatternRewriter &rewriter) const override {
6078df54a6aSJacques Pienaar     if (op.getHasInserts()) {
608f66e5769SAart Bik       // Finalize any pending insertions.
6092045cca0SAart Bik       StringRef name = "endLexInsert";
610ee986ab7SPeiming Liu       createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
611d8731bfcSwren romano                      EmitCInterface::Off);
61236b66ab9SAart Bik     }
613f66e5769SAart Bik     rewriter.replaceOp(op, adaptor.getOperands());
614f66e5769SAart Bik     return success();
61536b66ab9SAart Bik   }
616f66e5769SAart Bik };
617f66e5769SAart Bik 
/// Sparse conversion rule for the insertion operator. Lowers a
/// tensor.insert into a sparse destination to a runtime "lexInsert<T>"
/// call, passing the level coordinates and the scalar value by
/// reference through stack-allocated memrefs.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Note that the current regime only allows for strict lexicographic
    // coordinate order. All values are passed by reference through stack
    // allocated memrefs.
    Location loc = op->getLoc();
    const auto stt = getSparseTensorType(op.getDest());

    // Dense tensor insertion is not handled by this pattern.
    if (!stt.hasEncoding())
      return failure();

    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    const auto elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    Value lvlCoords, vref;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Operation *loop = op;
      // Finds the outermost loop.
      while (auto l = loop->getParentOfType<LoopLikeOpInterface>())
        loop = l;

      if (llvm::isa<LoopLikeOpInterface>(loop)) {
        // Hoists alloca outside the loop to avoid stack overflow.
        rewriter.setInsertionPoint(loop);
      }
      // The allocas are emitted at the (possibly hoisted) insertion
      // point; the guard restores the rewriter afterwards.
      lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
      vref = genAllocaScalar(rewriter, loc, elemTp);
    }
    // Store the level coordinates and the scalar into the stack buffers,
    // then invoke the element-type-suffixed runtime insertion method.
    storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
    rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
    // The result simply aliases the destination's opaque pointer.
    rewriter.replaceOp(op, adaptor.getDest());
    return success();
  }
};
663727a63e0SAart Bik 
/// Sparse conversion rule for the expand operator. Materializes the
/// access-pattern expansion buffers (values, filled-switch, coordinates)
/// on the heap, initialized once on entry of the surrounding loop nest.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    const auto srcTp = getSparseTensorType(op.getTensor());
    Type eltType = srcTp.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // Get the cardinality of valid coordinates for the innermost level.
    Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
                                   srcTp.getLvlRank() - 1);
    // Allocate temporary buffers for values, filled-switch, and coordinates.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    // These heap buffers are deallocated by the conversion of the matching
    // compress operator.
    Value values = genAlloc(rewriter, loc, sz, eltType);
    Value filled = genAlloc(rewriter, loc, sz, boolType);
    Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
    return success();
  }
};
7054f2ec7f9SAart Bik 
/// Sparse conversion rule for the compress operator. Lowers to a runtime
/// "expInsert<T>" call and deallocates the expansion buffers after the
/// enclosing loop nest.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Note that this method call resets the values/filled-switch back to
    // all-zero/false by only iterating over the set elements, so the
    // complexity remains proportional to the sparsity of the expanded
    // access pattern.
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    Value tensor = adaptor.getTensor();
    const auto stt = getSparseTensorType(op.getTensor());
    const Type elemTp = stt.getElementType();
    const Level lvlRank = stt.getLvlRank();
    // Pass the level coordinates by reference through a stack-allocated
    // memref, then invoke the element-type-suffixed runtime method.
    auto lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
    storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
    createFuncCall(rewriter, loc, name, {},
                   {tensor, lvlCoords, values, filled, added, count},
                   EmitCInterface::On);
    // The result simply aliases the destination's opaque pointer.
    rewriter.replaceOp(op, adaptor.getTensor());
    // Deallocate the buffers on exit of the loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    return success();
  }
};
7424f2ec7f9SAart Bik 
/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const Location loc = op->getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // All dimension sizes are statically known for assemble.
    assert(dstTp.hasStaticDimShape());
    SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
    // Use a library method to transfer the external buffers from
    // clients to the internal SparseTensorStorage. Since we cannot
    // assume clients transfer ownership of the buffers, this method
    // will copy all data over into a new SparseTensorStorage.
    // Note: genLvlPtrsBuffers gathers the client-provided level and
    // value buffers into the argument expected by the kPack action.
    Value dst =
        NewCallParams(rewriter, loc)
            .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
            .genNewCall(Action::kPack,
                        genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
                                          adaptor.getValues()));
    rewriter.replaceOp(op, dst);
    return success();
  }
};
768fa6726e2SPeiming Liu 
/// Sparse conversion rule for the sparse_tensor.disassemble operator.
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying it to the external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto stt = getSparseTensorType(op.getTensor());
    // Collect the returned buffers (retVal) together with the actual
    // length used in each of them (retLen).
    SmallVector<Value> retVal;
    SmallVector<Value> retLen;
    // Get the positions and coordinates buffers.
    const Level lvlRank = stt.getLvlRank();
    Level trailCOOLen = 0;
    for (Level l = 0; l < lvlRank; l++) {
      if (!stt.isUniqueLvl(l) &&
          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
        // A `(loose)compressed_nu` level marks the start of trailing COO
        // start level. Since the target coordinate buffer used for trailing
        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
        // scheme, we cannot simply use the internal buffers.
        trailCOOLen = lvlRank - l;
        break;
      }
      if (stt.isWithPos(l)) {
        auto poss =
            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(poss);
        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      }
      if (stt.isWithCrd(l)) {
        auto crds =
            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retVal.push_back(crds);
        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
      }
    }
    // Handle AoS vs. SoA mismatch for COO.
    if (trailCOOLen != 0) {
      uint64_t cooStartLvl = lvlRank - trailCOOLen;
      assert(!stt.isUniqueLvl(cooStartLvl) &&
             (stt.isCompressedLvl(cooStartLvl) ||
              stt.isLooseCompressedLvl(cooStartLvl)));
      // Positions.
      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                   cooStartLvl);
      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(poss);
      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
      // Coordinates, copied over with:
      //    for (i = 0; i < crdLen; i++)
      //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
      // Note: this interleaving copy only covers the two-level COO case
      // (crds0 at cooStartLvl and crds1 at cooStartLvl + 1).
      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl);
      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                      cooStartLvl + 1);
      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
      auto two = constantIndex(rewriter, loc, 2);
      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
      Type indexType = rewriter.getIndexType();
      auto zero = constantZero(rewriter, loc, indexType);
      auto one = constantOne(rewriter, loc, indexType);
      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
      auto idx = forOp.getInductionVar();
      rewriter.setInsertionPointToStart(forOp.getBody());
      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
      SmallVector<Value> args;
      args.push_back(idx);
      args.push_back(zero);
      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
      args[1] = one;
      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
      rewriter.setInsertionPointAfter(forOp);
      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retVal.push_back(buf);
      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
    }
    // Get the values buffer last.
    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
    auto valLenTp = op.getValLen().getType();
    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
    retVal.push_back(vals);
    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));

    // Converts MemRefs back to Tensors.
    assert(retVal.size() + retLen.size() == op.getNumResults());
    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
      retVal[i] =
          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
    }

    // Appends the actual memory length used in each buffer returned.
    retVal.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retVal);
    return success();
  }
};
881af8428c0SAart Bik 
882e8e8df4cSMatthias Springer struct SparseHasRuntimeLibraryConverter
883e8e8df4cSMatthias Springer     : public OpConversionPattern<HasRuntimeLibraryOp> {
884e8e8df4cSMatthias Springer   using OpConversionPattern::OpConversionPattern;
885e8e8df4cSMatthias Springer   LogicalResult
886e8e8df4cSMatthias Springer   matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
887e8e8df4cSMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
888e8e8df4cSMatthias Springer     auto i1Type = rewriter.getI1Type();
889e8e8df4cSMatthias Springer     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
890e8e8df4cSMatthias Springer         op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
891e8e8df4cSMatthias Springer     return success();
892e8e8df4cSMatthias Springer   }
893e8e8df4cSMatthias Springer };
894e8e8df4cSMatthias Springer 
895a2c9d4bbSAart Bik } // namespace
896a2c9d4bbSAart Bik 
89705c7f450SAart Bik //===----------------------------------------------------------------------===//
89886b22d31SAart Bik // Sparse tensor type conversion into opaque pointer.
89986b22d31SAart Bik //===----------------------------------------------------------------------===//
90086b22d31SAart Bik 
// Registers the type conversions: every type maps to itself by default,
// while sparse tensor types are converted into opaque pointers to the
// underlying sparse storage scheme (later-registered conversions take
// precedence, so registration order matters here).
mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSparseTensorTypes);
}
90586b22d31SAart Bik 
90686b22d31SAart Bik //===----------------------------------------------------------------------===//
90705c7f450SAart Bik // Public method for populating conversion rules.
90805c7f450SAart Bik //===----------------------------------------------------------------------===//
90905c7f450SAart Bik 
910a2c9d4bbSAart Bik /// Populates the given patterns list with conversion rules required for
911a2c9d4bbSAart Bik /// the sparsification of linear algebra operations.
912*206fad0eSMatthias Springer void mlir::populateSparseTensorConversionPatterns(
913*206fad0eSMatthias Springer     const TypeConverter &typeConverter, RewritePatternSet &patterns) {
914f248d0b2SPeiming Liu   patterns
915c780352dSPeiming Liu       .add<SparseReturnConverter, SparseTensorLvlOpConverter,
916ef222988SPeiming Liu            SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
917c3b01b46SPeiming Liu            SparseTensorAllocConverter, SparseTensorEmptyConverter,
918f248d0b2SPeiming Liu            SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
919f248d0b2SPeiming Liu            SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
920dc4cfdbbSAart Bik            SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
921dc4cfdbbSAart Bik            SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
922dc4cfdbbSAart Bik            SparseTensorInsertConverter, SparseTensorExpandConverter,
923dc4cfdbbSAart Bik            SparseTensorCompressConverter, SparseTensorAssembleConverter,
924dc4cfdbbSAart Bik            SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
925dc4cfdbbSAart Bik           typeConverter, patterns.getContext());
926a2c9d4bbSAart Bik }
927