xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (revision 9df63b2651b2435c02a7d825953ca2ddc65c778e)
1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to compiler-visible
10 // buffers and compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library (other than for file I/O), and providing many more
14 // opportunities for subsequent compiler optimization of the generated code.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "Utils/CodegenUtils.h"
19 #include "Utils/SparseTensorDescriptor.h"
20 
21 #include "mlir/Dialect/Arith/Utils/Utils.h"
22 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/Linalg/Utils/Utils.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
27 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30 #include "mlir/Dialect/Tensor/IR/Tensor.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Flatten the given value ranges into a single vector of values.
43 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44   SmallVector<Value> result;
45   for (const auto &vals : values)
46     llvm::append_range(result, vals);
47   return result;
48 }
49 
50 /// Assert that the given value range contains a single value and return it.
51 static Value getSingleValue(ValueRange values) {
52   assert(values.size() == 1 && "expected single value");
53   return values.front();
54 }
55 
56 /// Generates a load with proper `index` typing.
57 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
58   idx = genCast(builder, loc, idx, builder.getIndexType());
59   return builder.create<memref::LoadOp>(loc, mem, idx);
60 }
61 
62 /// Generates a store with proper `index` typing and proper value.
63 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
64                      Value idx) {
65   idx = genCast(builder, loc, idx, builder.getIndexType());
66   val = genCast(builder, loc, val,
67                 cast<ShapedType>(mem.getType()).getElementType());
68   builder.create<memref::StoreOp>(loc, val, mem, idx);
69 }
70 
71 /// Creates a straightforward counting for-loop.
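/// A typical use threads loop-carried buffers through the iteration arguments
/// (a usage sketch, not tied to any particular caller):
///   SmallVector<Value> fields = ...;   // buffers carried through the loop
///   scf::ForOp loop = createFor(builder, loc, upperBound, fields);
///   ... build the body using the updated `fields`, then scf.yield them ...
///   builder.setInsertionPointAfter(loop);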
72 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
73                             MutableArrayRef<Value> fields,
74                             Value lower = Value()) {
75   Type indexType = builder.getIndexType();
76   if (!lower)
77     lower = constantZero(builder, loc, indexType);
78   Value one = constantOne(builder, loc, indexType);
79   scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
80   for (unsigned i = 0, e = fields.size(); i < e; i++)
81     fields[i] = forOp.getRegionIterArg(i);
82   builder.setInsertionPointToStart(forOp.getBody());
83   return forOp;
84 }
85 
86 /// Creates a push-back operation.
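/// For illustration, appending one coordinate to the crd memref of a level
/// roughly materializes (a sketch; exact types depend on the encoding):
///   %buf, %newSize = sparse_tensor.push_back %curSize, %crdMemRef, %crd
///                      : index, memref<?xindex>, index
/// after which the descriptor is updated with %buf and %newSize.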
87 static void createPushback(OpBuilder &builder, Location loc,
88                            MutSparseTensorDescriptor desc,
89                            SparseTensorFieldKind kind, std::optional<Level> lvl,
90                            Value value, Value repeat = Value()) {
91   Type etp = desc.getMemRefElementType(kind, lvl);
92   Value field = desc.getMemRefField(kind, lvl);
93   StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
94 
95   auto pushBackOp = builder.create<PushBackOp>(
96       loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
97       genCast(builder, loc, value, etp), repeat);
98 
99   desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
100   desc.setSpecifierField(builder, loc, specFieldKind, lvl,
101                          pushBackOp.getNewSize());
102 }
103 
104 /// Generates code that allocates a sparse storage scheme, starting at the
105 /// given level.
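/// As a sketch of the effect (not of the exact IR): for a (dense, compressed)
/// matrix with startLvl == 0, the dense level contributes its size d0 to
/// `linear`, and the compressed level then receives d0 additional zero
/// position entries on top of its initial zero entry, preserving the
/// "linear + 1" length property.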
105 static void allocSchemeForRank(OpBuilder &builder, Location loc,
106                                MutSparseTensorDescriptor desc, Level startLvl) {
107   const SparseTensorType stt(desc.getRankedTensorType());
108   Value linear = constantIndex(builder, loc, 1);
109   const Level lvlRank = stt.getLvlRank();
110   for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
111     const auto lt = stt.getLvlType(lvl);
112     if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
113       // Append linear x positions, initialized to zero. Since each compressed
114       // dimension initially already has a single zero entry, this maintains
115       // the desired "linear + 1" length property at all times. For loose
116       // compression, we multiply linear by two in order to append both the
117       // lo/hi positions.
118       Value posZero = constantZero(builder, loc, stt.getPosType());
119       if (isLooseCompressedLT(lt)) {
120         Value two = constantIndex(builder, loc, 2);
121         linear = builder.create<arith::MulIOp>(loc, linear, two);
122       }
123       createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
124                      /*value=*/posZero, /*repeat=*/linear);
125       return;
126     } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
127       return; // nothing to do
128     }
129     // Keep compounding the size, but nothing needs to be initialized
130     // at this level. We will eventually reach either a compressed level
131     // or the values array for the remaining "all-dense" case.
132     assert(isDenseLT(lt));
133     Value size = desc.getLvlSize(builder, loc, lvl);
134     linear = builder.create<arith::MulIOp>(loc, linear, size);
135   }
136   // Reached values array so prepare for an insertion.
137   Value valZero = constantZero(builder, loc, stt.getElementType());
138   createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
139                  std::nullopt, /*value=*/valZero, /*repeat=*/linear);
140 }
141 
142 /// Creates an allocation operation, optionally zero-initialized.
143 static Value createAllocation(OpBuilder &builder, Location loc,
144                               MemRefType memRefType, Value sz,
145                               bool enableInit) {
146   Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
147   Type elemType = memRefType.getElementType();
148   if (enableInit) {
149     Value fillValue = constantZero(builder, loc, elemType);
150     builder.create<linalg::FillOp>(loc, fillValue, buffer);
151   }
152   return buffer;
153 }
154 
155 /// Creates the dim sizes array, filling in from dynamic sizes.
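/// For instance, for a static/dynamic shape 10x?x20 with a single dynamic
/// size %d, the resulting array is [%c10, %d, %c20].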
156 static void createDimSizes(OpBuilder &builder, Location loc,
157                            SparseTensorType stt, ValueRange dynSizes,
158                            /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
159   const Dimension dimRank = stt.getDimRank();
160   dimSizesValues.clear();
161   dimSizesValues.reserve(dimRank);
162   unsigned i = 0;
163   for (const Size sz : stt.getDimShape())
164     dimSizesValues.push_back(ShapedType::isDynamic(sz)
165                                  ? dynSizes[i++]
166                                  : constantIndex(builder, loc, sz));
167 }
168 
169 /// Creates allocations for each field in the sparse tensor type. Note that
170 /// for all dynamic memrefs in the sparse tensor storage layout, the
171 /// memory size is really the capacity of the "vector", while the actual
172 /// size resides in the sizes array.
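/// As an illustration of the heuristics below: for a 2-D dense-then-compressed
/// (CSR-like) tensor with a size hint of nnz, positions get an initial
/// capacity of nnz + 1 while coordinates and values get nnz; without any size
/// hint, all buffers start with a small constant capacity of 16.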
173 static void createAllocFields(OpBuilder &builder, Location loc,
174                               SparseTensorType stt, bool enableInit,
175                               Value sizeHint,
176                               SmallVectorImpl<Value> &lvlSizesValues,
177                               /*out*/ SmallVectorImpl<Value> &fields) {
178   Level lvlRank = stt.getLvlRank();
179   // Set up some heuristic sizes. We try to set the initial
180   // size based on available information. Otherwise we just
181   // initialize a few elements to start the reallocation chain.
182   // TODO: refine this
183   Value posHeuristic, crdHeuristic, valHeuristic;
184   if (stt.isAllDense()) {
185     valHeuristic = lvlSizesValues[0];
186     for (Level lvl = 1; lvl < lvlRank; lvl++)
187       valHeuristic =
188           builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
189   } else if (sizeHint) {
190     if (stt.getAoSCOOStart() == 0) {
191       posHeuristic = constantIndex(builder, loc, 2);
192       crdHeuristic = builder.create<arith::MulIOp>(
193           loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
194     } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
195       posHeuristic = builder.create<arith::AddIOp>(
196           loc, sizeHint, constantIndex(builder, loc, 1));
197       crdHeuristic = sizeHint;
198     } else {
199       posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
200     }
201     valHeuristic = sizeHint;
202   } else {
203     posHeuristic = crdHeuristic = valHeuristic =
204         constantIndex(builder, loc, 16);
205   }
206   // Initialize all fields: an initial storage specifier and allocated
207   // positions/coordinates/values memrefs (with heuristic capacity).
208   foreachFieldAndTypeInSparseTensor(
209       stt,
210       [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
211        enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
212                    Level /*lvl*/, LevelType /*lt*/) -> bool {
213         assert(fields.size() == fIdx);
214         Value field;
215         switch (fKind) {
216         case SparseTensorFieldKind::StorageSpec:
217           field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
218           break;
219         case SparseTensorFieldKind::PosMemRef:
220           field = createAllocation(builder, loc, cast<MemRefType>(fType),
221                                    posHeuristic, enableInit);
222           break;
223         case SparseTensorFieldKind::CrdMemRef:
224           field = createAllocation(builder, loc, cast<MemRefType>(fType),
225                                    crdHeuristic, enableInit);
226           break;
227         case SparseTensorFieldKind::ValMemRef:
228           field = createAllocation(builder, loc, cast<MemRefType>(fType),
229                                    valHeuristic, enableInit);
230           break;
231         }
232         assert(field);
233         fields.push_back(field);
234         // Returns true to continue the iteration.
235         return true;
236       });
237   // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
238   // and gives all position fields an initial zero entry, so that it is
239   // easier to maintain the "linear + 1" length property.
240   MutSparseTensorDescriptor desc(stt, fields);
241   Value posZero = constantZero(builder, loc, stt.getPosType());
242   for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
243     desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
244     const auto lt = stt.getLvlType(lvl);
245     if (isCompressedLT(lt) || isLooseCompressedLT(lt))
246       createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
247                      /*value=*/posZero);
248   }
249   allocSchemeForRank(builder, loc, desc, /*startLvl=*/0);
250 }
251 
252 /// Helper method that generates the block specific to the compressed case:
253 ///
254 ///  // given: parentPos = posCursor[lvl-1]
255 ///  pstart = desc.positions[lvl][parentPos]
256 ///  pstop = desc.positions[lvl][parentPos+1]
257 ///  plast = pstop - 1
258 ///  msz = desc.coordinates[lvl].size()
259 ///  if (pstart < pstop) {
260 ///    isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
261 ///  } else { // first insertion
262 ///    isPresent = false
263 ///    desc.positions[lvl][parentPos] = msz
264 ///  }
265 ///  if (isPresent) { // coordinate is already present
266 ///    pnext = plast
267 ///  } else {
268 ///    desc.coordinates[lvl].push_back(lvlCoords[lvl])
269 ///    desc.positions[lvl][parentPos+1] = msz+1
270 ///    pnext = msz
271 ///    <prepare level lvl+1>
272 ///  }
273 ///  posCursor[lvl] = pnext
274 static Value genCompressed(OpBuilder &builder, Location loc,
275                            MutSparseTensorDescriptor desc, ValueRange lvlCoords,
276                            Value /*unused*/, Value parentPos, Level lvl) {
277   const SparseTensorType stt(desc.getRankedTensorType());
278   const Level lvlRank = stt.getLvlRank();
279   assert(lvl < lvlRank && "Level is out of bounds");
280   assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
281          "Level-rank mismatch");
282   SmallVector<Type> types;
283   Type indexType = builder.getIndexType();
284   Type boolType = builder.getIntegerType(1);
285   unsigned crdFidx;
286   unsigned crdStride;
287   std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
288   const Value one = constantIndex(builder, loc, 1);
289   const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
290   const Value positionsAtLvl = desc.getPosMemRef(lvl);
291   const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
292   const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
293   const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
294   const Value crdStrideC =
295       crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
296   const Value msz =
297       crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
298                  : crdMsz;
299   const Value plast = builder.create<arith::SubIOp>(
300       loc, genCast(builder, loc, pstop, indexType), one);
301   // Conditional expression.
302   Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
303                                            pstart, pstop);
304   types.push_back(boolType);
305   scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
306   types.pop_back();
307   builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
308   Value crd =
309       genLoad(builder, loc, desc.getMemRefField(crdFidx),
310               crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
311                          : plast);
312   Value eq = builder.create<arith::CmpIOp>(
313       loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
314       lvlCoords[lvl]);
315   builder.create<scf::YieldOp>(loc, eq);
316   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
317   if (lvl > 0)
318     genStore(builder, loc, msz, positionsAtLvl, parentPos);
319   builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
320   builder.setInsertionPointAfter(ifOp1);
321   // The "if present" construct. Note that for a non-unique level, we simply
322   // set the condition to false and rely on CSE/DCE to clean up the IR.
323   //
324   // TODO: generate less temporary IR?
325   //
326   for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
327     types.push_back(desc.getField(i).getType());
328   types.push_back(indexType);
329   const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
330                                        : constantI1(builder, loc, false);
331   scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
332   // If present (fields unaffected, update pnext to plast).
333   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
334 
335   // FIXME: This does not look like a clean way, but it is probably the most
336   // efficient way.
337   desc.getFields().push_back(plast);
338   builder.create<scf::YieldOp>(loc, desc.getFields());
339   desc.getFields().pop_back();
340 
341   // If !present (changes fields, update pnext).
342   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
343   Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
344   genStore(builder, loc, mszp1, positionsAtLvl, pp1);
345   createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
346                  /*value=*/lvlCoords[lvl]);
347   // Prepare the next level "as needed".
348   if ((lvl + 1) < lvlRank)
349     allocSchemeForRank(builder, loc, desc, lvl + 1);
350 
351   desc.getFields().push_back(msz);
352   builder.create<scf::YieldOp>(loc, desc.getFields());
353   desc.getFields().pop_back();
354 
355   // Update fields and return next pos.
356   builder.setInsertionPointAfter(ifOp2);
357   unsigned o = 0;
358   for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
359     desc.setField(i, ifOp2.getResult(o++));
360   return ifOp2.getResult(o);
361 }
362 
363 /// Generates insertion finalization code.
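/// For example, if a compressed level left its positions as [0, 2, 0, 0, 5]
/// because some entries were never visited, the cleanup loop below propagates
/// the last stored value into the zero gaps, yielding [0, 2, 2, 2, 5]
/// (a sketch of the effect on the buffer contents, not of the emitted IR).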
364 static void genEndInsert(OpBuilder &builder, Location loc,
365                          SparseTensorDescriptor desc) {
366   const SparseTensorType stt(desc.getRankedTensorType());
367   const Level lvlRank = stt.getLvlRank();
368   for (Level lvl = 0; lvl < lvlRank; lvl++) {
369     const auto lt = stt.getLvlType(lvl);
370     if (isCompressedLT(lt)) {
371       // Compressed dimensions need a position cleanup for all entries
372       // that were not visited during the insertion pass.
373       //
374       // TODO: avoid cleanup and keep compressed scheme consistent at all
375       // times?
376       //
377       if (lvl > 0) {
378         Type posType = stt.getPosType();
379         Value posMemRef = desc.getPosMemRef(lvl);
380         Value hi = desc.getPosMemSize(builder, loc, lvl);
381         Value zero = constantIndex(builder, loc, 0);
382         Value one = constantIndex(builder, loc, 1);
383         // Vector of only one, but needed by createFor's prototype.
384         SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
385         scf::ForOp loop = createFor(builder, loc, hi, inits, one);
386         Value i = loop.getInductionVar();
387         Value oldv = loop.getRegionIterArg(0);
388         Value newv = genLoad(builder, loc, posMemRef, i);
389         Value posZero = constantZero(builder, loc, posType);
390         Value cond = builder.create<arith::CmpIOp>(
391             loc, arith::CmpIPredicate::eq, newv, posZero);
392         scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
393                                                    cond, /*else*/ true);
394         builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
395         genStore(builder, loc, oldv, posMemRef, i);
396         builder.create<scf::YieldOp>(loc, oldv);
397         builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
398         builder.create<scf::YieldOp>(loc, newv);
399         builder.setInsertionPointAfter(ifOp);
400         builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
401         builder.setInsertionPointAfter(loop);
402       }
403     } else {
404       assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
405              isNOutOfMLT(lt));
406     }
407   }
408 }
409 
410 /// Generates a subview of the given memref, truncated to the given size.
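/// For a linear memref this roughly produces (a sketch; the element type
/// depends on the field):
///   %view = memref.subview %mem[0] [%sz] [1] : memref<?xf64> to memref<?xf64>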
411 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
412                             Value sz) {
413   auto memTp = llvm::cast<MemRefType>(mem.getType());
414   // For higher-dimensional memrefs, we assume that the innermost
415   // dimension is always of the right size.
416   // TODO: generate complex truncating view here too?
417   if (memTp.getRank() > 1)
418     return mem;
419   // Truncate linear memrefs to given size.
420   return builder
421       .create<memref::SubViewOp>(
422           loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
423           mem, ValueRange{}, ValueRange{sz}, ValueRange{},
424           ArrayRef<int64_t>{0},                    // static offset
425           ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
426           ArrayRef<int64_t>{1})                    // static stride
427       .getResult();
428 }
429 
430 /// Creates the reassociation array.
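/// For example, with a rank-4 source and 2 batch levels the result is
/// {{0}, {1}, {2, 3}}, i.e. all non-batch dimensions collapse into one group.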
431 static SmallVector<ReassociationIndices>
432 getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
433   SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
434   // Create the reassociation in the form:
435   //   {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}
436   for (unsigned i = 0; i < batchLvls; i++)
437     ret[i].push_back(i);
438 
439   for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
440     ret.back().push_back(i);
441 
442   return ret;
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // Codegen rules.
447 //===----------------------------------------------------------------------===//
448 
449 namespace {
450 
451 /// Helper class for lowering the sparse_tensor.insert operation.
452 class SparseInsertGenerator
453     : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
454 public:
455   SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
456                         bool genCall)
457       : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
458 
459   /// Generates code along an insertion path without the need for a "cursor".
460   /// This current insertion strategy comes at the expense of some testing
461   /// overhead for each insertion. The strategy will be optimized later for
462   /// common insertion patterns. The current insertion strategy also assumes
463   /// insertions occur in "a reasonable order" that enables building the
464   /// storage scheme in an appending/inserting kind of fashion (i.e. no
465   /// in-between insertions that need data movement). The implementation
466   /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
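  /// As an example of "a reasonable order": for a CSR-like tensor, inserting
  /// coordinates in lexicographic order, e.g. (0,2), (0,5), (1,3), only ever
  /// appends to the coordinate/value arrays, which is the pattern this scheme
  /// supports without data movement.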
467   ///
468   /// TODO: better unord/not-unique; also generalize, optimize, specialize!
469   SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
470                                        OpBuilder &builder, Location loc) {
471     const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
472     const Level lvlRank = stt.getLvlRank();
473     // Extract fields and coordinates from args.
474     SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
475     MutSparseTensorDescriptor desc(stt, fields);
476     const SmallVector<Value> coords =
477         llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
478     Value value = args.back();
479     Value parentPos = constantZero(builder, loc, builder.getIndexType());
480     // Generate code for every level.
481     for (Level lvl = 0; lvl < lvlRank; lvl++) {
482       const auto lt = stt.getLvlType(lvl);
483       if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
484         // Create:
485         //   if (!present) {
486         //     coordinates[lvl].push_back(coords[lvl])
487         //     <update positions and prepare level lvl + 1>
488         //   }
489         //   positions[lvl] = coordinates.size() - 1
490         //   <insert @ positions[lvl] at next level lvl + 1>
491         if (isLooseCompressedLT(lt)) {
492           Value two = constantIndex(builder, loc, 2);
493           parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
494         }
495         parentPos =
496             genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
497       } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
498         // Create:
499         //   coordinates[lvl].push_back(coords[lvl])
500         //   positions[lvl] = positions[lvl-1]
501         //   <insert @ positions[lvl] at next level lvl + 1>
502         createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
503                        lvl, /*value=*/coords[lvl]);
504       } else {
505         assert(isDenseLT(lt));
506         // Construct the new position as:
507         //   positions[lvl] = size * positions[lvl-1] + coords[lvl]
508         //   <insert @ positions[lvl] at next level lvl + 1>
509         Value size = desc.getLvlSize(builder, loc, lvl);
510         Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
511         parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
512       }
513     }
514     // Reached the actual value append/insert.
515     if (!stt.isDenseLvl(lvlRank - 1))
516       createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
517                      std::nullopt, value);
518     else
519       genStore(builder, loc, value, desc.getValMemRef(), parentPos);
520     return fields;
521   }
522 
523   std::string getMangledFuncName() {
524     // The mangled name of the function has this format:
525     //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
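    // For illustration only (a hypothetical 10x20 CSR-like tensor of f64 with
    // the default crd/pos bitwidths, which print as 0), the mangled name could
    // look like "_insert_dense_compressed_10_20_f64_0_0".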
526     constexpr const char kInsertFuncNamePrefix[] = "_insert_";
527     const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
528     SmallString<32> nameBuffer;
529     llvm::raw_svector_ostream nameOstream(nameBuffer);
530     nameOstream << kInsertFuncNamePrefix;
531     const Level lvlRank = stt.getLvlRank();
532     for (Level l = 0; l < lvlRank; l++) {
533       std::string lvlType = toMLIRString(stt.getLvlType(l));
534       // Replace/remove punctuations in level properties.
535       std::replace_if(
536           lvlType.begin(), lvlType.end(),
537           [](char c) { return c == '(' || c == ','; }, '_');
538       llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
539       nameOstream << lvlType << "_";
540     }
541     // Static dim sizes are used in the generated code while dynamic sizes are
542     // loaded from the dimSizes buffer. This is the reason for adding the shape
543     // to the function name.
544     for (const auto sz : stt.getDimShape())
545       nameOstream << sz << "_";
546     // Permutation information is also used in generating insertion.
547     if (!stt.isIdentity())
548       nameOstream << stt.getDimToLvl() << "_";
549     nameOstream << stt.getElementType() << "_";
550     nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
551     return nameOstream.str().str();
552   }
553 
554 private:
555   TensorType rtp;
556 };
557 
558 /// Sparse tensor storage conversion rule for returns.
559 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
560 public:
561   using OpConversionPattern::OpConversionPattern;
562   LogicalResult
563   matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
564                   ConversionPatternRewriter &rewriter) const override {
565     // Create a return with the flattened value extracted from sparse tensors.
566     rewriter.replaceOpWithNewOp<func::ReturnOp>(
567         op, flattenValues(adaptor.getOperands()));
568     return success();
569   }
570 };
571 
572 /// Sparse tensor storage conversion rule for calls.
573 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
574 public:
575   // The default CallOp converter cannot handle 1:N type conversions.
576   using OpConversionPattern::OpConversionPattern;
577   LogicalResult
578   matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
579                   ConversionPatternRewriter &rewriter) const override {
580     Location loc = op.getLoc();
581     // In case of:
582     //  sparse_tensor, f, sparse_tensor = call @foo(...)
583     // ==>
584     //  memref..., f, memref... = call @foo(...)
585     //  with each sparse tensor result reassembled from its group of memrefs.
586     SmallVector<Type> finalRetTy;
587     if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
588       return failure();
589 
590     // (1) Generates new call with flattened return value.
591     auto newCall = rewriter.create<func::CallOp>(
592         loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
593     // (2) Gather sparse tensor returns.
594     SmallVector<SmallVector<Value>> packedResultVals;
595     // Tracks the offset of the current return value (of the original call)
596     // relative to the new call (after sparse tensor flattening).
597     unsigned retOffset = 0;
598     // Temporary buffer to hold the flattened list of types for
599     // a sparse tensor.
600     SmallVector<Type> sparseFlat;
601     for (auto ret : op.getResults()) {
602       assert(retOffset < newCall.getNumResults());
603       auto retType = ret.getType();
604       if (failed(typeConverter->convertType(retType, sparseFlat)))
605         llvm_unreachable("Failed to convert type in sparse tensor codegen");
606 
607       // Converted types cannot be empty when the type conversion succeeds.
608       assert(!sparseFlat.empty());
609       if (sparseFlat.size() > 1) {
610         auto flatSize = sparseFlat.size();
611         packedResultVals.emplace_back();
612         llvm::append_range(packedResultVals.back(),
613                            newCall.getResults().slice(retOffset, flatSize));
614         retOffset += flatSize;
615       } else {
616         // If this is a 1:1 conversion, there is no need for casting.
617         packedResultVals.emplace_back();
618         packedResultVals.back().push_back(newCall.getResult(retOffset));
619         retOffset++;
620       }
621       sparseFlat.clear();
622     }
623 
624     assert(packedResultVals.size() == op.getNumResults());
625     rewriter.replaceOpWithMultiple(
626         op, llvm::to_vector_of<ValueRange>(packedResultVals));
627     return success();
628   }
629 };
630 
631 /// Sparse codegen rule for level accesses.
632 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
633 public:
634   using OpConversionPattern::OpConversionPattern;
635   LogicalResult
636   matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
637                   ConversionPatternRewriter &rewriter) const override {
638     std::optional<int64_t> lvl = op.getConstantLvlIndex();
639     RankedTensorType srcType = op.getSource().getType();
640     if (!lvl || !getSparseTensorEncoding(srcType))
641       return failure();
642 
643     auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
644     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
645 
646     rewriter.replaceOp(op, sz);
647     return success();
648   }
649 };
650 
651 // TODO: use a new SortCOO operation here instead of reusing convert op.
652 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
653   using OpConversionPattern::OpConversionPattern;
654   LogicalResult
655   matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
656                   ConversionPatternRewriter &rewriter) const override {
657     Location loc = op.getLoc();
658     MLIRContext *ctx = op.getContext();
659 
660     SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
661     SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
662 
663     // Should have been verified.
664     assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
665            dstStt.isCOOType() && srcStt.isCOOType());
666     assert(dstStt.hasSameDimToLvl(srcStt));
667 
668     // We don't need a mutable descriptor here as we perform sorting in-place.
669     auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
670                                              op.getInputCoo().getType());
671     auto nnz = desc.getValMemSize(rewriter, op.getLoc());
672     auto crd = desc.getAOSMemRef();
673     auto val = desc.getValMemRef();
674 
675     // A data shuffle with a non-identity map would be needed otherwise.
676     assert(dstStt.hasSameDimToLvl(srcStt));
677     (void)dstStt; // to silence warning when assertion is disabled
678 
679     auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
680 
681     rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
682                             rewriter.getIndexAttr(0), op.getAlgorithm());
683 
684     // Since we do in-place sorting, the destination tensor will have the same
685     // set of memrefs as the source tensor.
686     rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
687     return success();
688   }
689 };
690 
691 template <typename Op, StorageSpecifierKind kind>
692 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
693 public:
694   using OpConversionPattern<Op>::OpConversionPattern;
695   using typename OpConversionPattern<Op>::OneToNOpAdaptor;
696 
697   LogicalResult
698   matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
699                   ConversionPatternRewriter &rewriter) const override {
700     // Simply lowers to a specifier.get <field> operation.
701     auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
702                                              op.getSlice().getType());
703     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
704                                     op.getDim().getZExtValue());
705 
706     rewriter.replaceOp(op, v);
707     return success();
708   }
709 };
710 
711 /// Sparse codegen rule for trivial tensor casts.
712 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
713 public:
714   using OpConversionPattern::OpConversionPattern;
715   LogicalResult
716   matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
717                   ConversionPatternRewriter &rewriter) const override {
718     // Only rewrite identically annotated source/dest.
719     auto encDst = getSparseTensorEncoding(op.getType());
720     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
721     if (!encDst || encDst != encSrc)
722       return failure();
723     rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
724     return success();
725   }
726 };
727 
728 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
729 public:
730   using OpConversionPattern::OpConversionPattern;
731   LogicalResult
732   matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
733                   ConversionPatternRewriter &rewriter) const override {
734     // Simply fold the operation.
735     rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
736     return success();
737   }
738 };
739 
740 /// Sparse codegen rule for the alloc operator.
741 class SparseTensorAllocConverter
742     : public OpConversionPattern<bufferization::AllocTensorOp> {
743 public:
744   using OpConversionPattern::OpConversionPattern;
745   SparseTensorAllocConverter(const TypeConverter &typeConverter,
746                              MLIRContext *context, bool enableInit)
747       : OpConversionPattern(typeConverter, context),
748         enableBufferInitialization(enableInit) {}
749 
750   LogicalResult
751   matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
752                   ConversionPatternRewriter &rewriter) const override {
753     const auto resType = getSparseTensorType(op);
754     if (!resType.hasEncoding())
755       return failure();
756 
757     Location loc = op.getLoc();
758     // Deal with copy.
759     if (op.getCopy()) {
760       auto desc = getDescriptorFromTensorTuple(
761           adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
762       SmallVector<Value> fields;
763       fields.reserve(desc.getNumFields());
764       // Memcpy on memref fields.
765       for (auto field : desc.getMemRefFields()) {
766         auto memrefTp = cast<MemRefType>(field.getType());
767         auto size = rewriter.create<memref::DimOp>(loc, field, 0);
768         auto copied =
769             rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
770         rewriter.create<memref::CopyOp>(loc, field, copied);
771         fields.push_back(copied);
772       }
773       // Reuses specifier.
774       fields.push_back(desc.getSpecifier());
775       assert(fields.size() == desc.getNumFields());
776       rewriter.replaceOpWithMultiple(op, {fields});
777       return success();
778     }
779 
780     if (!resType.isIdentity()) {
781       return rewriter.notifyMatchFailure(
782           op, "try run --sparse-reinterpret-map before codegen");
783     }
784     // Level sizes equal dimension sizes since the lvl2dim map is the identity.
785     SmallVector<Value> lvlSizesValues;
786     createDimSizes(rewriter, loc, resType,
787                    flattenValues(adaptor.getDynamicSizes()),
788                    /*dimSizesValues=*/lvlSizesValues);
789 
790     // Construct allocation for each field.
791     Value sizeHint = op.getSizeHint();
792     SmallVector<Value> fields;
793     createAllocFields(rewriter, loc, resType, enableBufferInitialization,
794                       sizeHint, lvlSizesValues, fields);
795 
796     // Replace operation with resulting memrefs.
797     rewriter.replaceOpWithMultiple(op, {fields});
798     return success();
799   }
800 
801 private:
802   bool enableBufferInitialization;
803 };
804 
805 /// Sparse codegen rule for the empty tensor operator.
806 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
807 public:
808   using OpConversionPattern::OpConversionPattern;
809   SparseTensorEmptyConverter(const TypeConverter &typeConverter,
810                              MLIRContext *context, bool enableInit)
811       : OpConversionPattern(typeConverter, context),
812         enableBufferInitialization(enableInit) {}
813 
814   LogicalResult
815   matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
816                   ConversionPatternRewriter &rewriter) const override {
817     const auto resType = getSparseTensorType(op);
818     if (!resType.hasEncoding())
819       return failure();
820 
821     if (!resType.isIdentity()) {
822       return rewriter.notifyMatchFailure(
823           op, "try run --sparse-reinterpret-map before codegen");
824     }
825 
826     Location loc = op.getLoc();
827     // Level sizes equal dimension sizes since the lvl2dim map is the identity.
828     SmallVector<Value> lvlSizesValues;
829     createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
830                    /*dimSizesValues=*/lvlSizesValues);
831     // Construct allocation for each field.
832     Value sizeHint; // none
833     SmallVector<Value> fields;
834     createAllocFields(rewriter, loc, resType, enableBufferInitialization,
835                       sizeHint, lvlSizesValues, fields);
836 
837     // Replace operation with resulting memrefs.
838     rewriter.replaceOpWithMultiple(op, {fields});
839     return success();
840   }
841 
842 private:
843   bool enableBufferInitialization;
844 };
845 
846 /// Sparse codegen rule for the dealloc operator.
847 class SparseTensorDeallocConverter
848     : public OpConversionPattern<bufferization::DeallocTensorOp> {
849 public:
850   using OpConversionPattern::OpConversionPattern;
851   SparseTensorDeallocConverter(const TypeConverter &typeConverter,
852                                MLIRContext *context, bool createDeallocs)
853       : OpConversionPattern(typeConverter, context),
854         createDeallocs(createDeallocs) {}
855 
856   LogicalResult
857   matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
858                   ConversionPatternRewriter &rewriter) const override {
859     auto enc = getSparseTensorEncoding(op.getTensor().getType());
860     if (!enc)
861       return failure();
862 
863     // If the user requests not to deallocate sparse tensors, simply erase the
864     // operation.
865     if (createDeallocs) {
866       // Replace the sparse tensor deallocation with field deallocations.
867       Location loc = op.getLoc();
868       auto desc = getDescriptorFromTensorTuple(
869           adaptor.getTensor(),
870           cast<RankedTensorType>(op.getTensor().getType()));
871       for (auto input : desc.getMemRefFields())
872         // Deallocate every buffer used to store the sparse tensor.
873         rewriter.create<memref::DeallocOp>(loc, input);
874     }
875     rewriter.eraseOp(op);
876     return success();
877   }
878 
879 private:
880   const bool createDeallocs;
881 };
882 
883 /// Sparse codegen rule for tensor rematerialization.
884 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
885 public:
886   using OpConversionPattern::OpConversionPattern;
887   LogicalResult
888   matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
889                   ConversionPatternRewriter &rewriter) const override {
890     // Prepare descriptor.
891     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
892                                              op.getTensor().getType());
893     // Generate optional insertion finalization code.
894     if (op.getHasInserts())
895       genEndInsert(rewriter, op.getLoc(), desc);
896     // Replace operation with resulting memrefs.
897     rewriter.replaceOpWithMultiple(op, {desc.getFields()});
898     return success();
899   }
900 };
901 
902 /// Sparse codegen rule for the expand op.
903 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
904 public:
905   using OpConversionPattern::OpConversionPattern;
906   LogicalResult
907   matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
908                   ConversionPatternRewriter &rewriter) const override {
909     if (!getSparseTensorEncoding(op.getTensor().getType()))
910       return failure();
911     Location loc = op->getLoc();
912     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
913                                              op.getTensor().getType());
914     const auto srcType = getSparseTensorType(op.getTensor());
915     Type eltType = srcType.getElementType();
916     Type boolType = rewriter.getIntegerType(1);
917     Type idxType = rewriter.getIndexType();
918     // All initialization should be done on entry of the loop nest.
919     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
920 
921     // Determine the size for access expansion (always the innermost stored
922     // level size).
923     const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
924     // Generate a memref for `sz` elements of type `t`.
925     const auto genAlloc = [&](Type t) {
926       const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
927       return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
928     };
929     // Allocate temporary buffers for values/filled-switch and added.
930     // We do not use stack buffers for this, since the expanded size may
931     // be rather large (as it envelops a single expanded dense dimension).
932     Value values = genAlloc(eltType);
933     Value filled = genAlloc(boolType);
934     Value added = genAlloc(idxType);
935     Value zero = constantZero(rewriter, loc, idxType);
936     // Reset the values/filled-switch to all-zero/false. Note that this
937     // introduces an O(N) operation into the computation, but this reset
938     // operation is amortized over the innermost loops for the access
939     // pattern expansion. As noted in the operation doc, we would like
940     // to amortize this setup cost even between kernels.
941     rewriter.create<linalg::FillOp>(
942         loc, ValueRange{constantZero(rewriter, loc, eltType)},
943         ValueRange{values});
944     rewriter.create<linalg::FillOp>(
945         loc, ValueRange{constantZero(rewriter, loc, boolType)},
946         ValueRange{filled});
947     // Replace expansion op with these buffers and initial coordinate.
948     assert(op.getNumResults() == 4);
949     rewriter.replaceOp(op, {values, filled, added, zero});
950     return success();
951   }
952 };
953 
954 /// Sparse codegen rule for the compress operator.
955 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
956 public:
957   using OpConversionPattern::OpConversionPattern;
958   LogicalResult
959   matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
960                   ConversionPatternRewriter &rewriter) const override {
961     Location loc = op->getLoc();
962     SmallVector<Value> fields;
963     auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
964                                                 op.getTensor().getType());
965     Value values = getSingleValue(adaptor.getValues());
966     Value filled = getSingleValue(adaptor.getFilled());
967     Value added = getSingleValue(adaptor.getAdded());
968     Value count = getSingleValue(adaptor.getCount());
969     const SparseTensorType dstType(desc.getRankedTensorType());
970     Type eltType = dstType.getElementType();
971 
972     // If the innermost level is ordered, we need to sort the coordinates
973     // in the "added" array prior to applying the compression.
974     if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
975       rewriter.create<SortOp>(
976           loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
977           rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
978     // While performing the insertions, we also need to reset the elements
979     // of the values/filled-switch by only iterating over the set elements,
980     // to ensure that the runtime complexity remains proportional to the
981     // sparsity of the expanded access pattern.
982     //
983     // Generate
984     //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
985     //      crd = added[i];
986     //      value = values[crd];
987     //      insert({lvlCoords, crd}, value);
988     //      new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
989     //      values[crd] = 0;
990     //      filled[crd] = false;
991     //      yield new_memrefs
992     //    }
993     scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
994     Value i = loop.getInductionVar();
995 
996     Value crd = genLoad(rewriter, loc, added, i);
997     Value value = genLoad(rewriter, loc, values, crd);
998     SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
999     SmallVector<Type> flatSpTensorTps = llvm::to_vector(
1000         llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
1001     SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
1002     params.append(flatLvlCoords.begin(), flatLvlCoords.end());
1003     params.push_back(crd);
1004     params.push_back(value);
1005     SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1006                                     params, /*genCall=*/true);
1007     SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1008     genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1009     genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1010     rewriter.create<scf::YieldOp>(loc, insertRet);
1011 
1012     rewriter.setInsertionPointAfter(loop);
1013     // Deallocate the buffers on exit of the full loop nest.
1014     Operation *parent = getTop(op);
1015     rewriter.setInsertionPointAfter(parent);
1016     rewriter.create<memref::DeallocOp>(loc, values);
1017     rewriter.create<memref::DeallocOp>(loc, filled);
1018     rewriter.create<memref::DeallocOp>(loc, added);
1019     // Replace operation with resulting memrefs.
1020     rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1021     return success();
1022   }
1023 };
1024 
1025 /// Sparse codegen rule for the insert operator.
1026 class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1027 public:
1028   using OpConversionPattern::OpConversionPattern;
1029   LogicalResult
1030   matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1031                   ConversionPatternRewriter &rewriter) const override {
1032     auto stt = getSparseTensorType(op.getDest());
1033     if (!stt.hasEncoding())
1034       return failure();
1035     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1036 
1037     Location loc = op.getLoc();
1038     auto desc =
1039         getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1040     TypeRange flatSpTensorTps = desc.getFields().getTypes();
1041     SmallVector<Value> params = llvm::to_vector(desc.getFields());
1042     SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1043     params.append(flatIndices.begin(), flatIndices.end());
1044     params.push_back(getSingleValue(adaptor.getScalar()));
1045     SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1046                                     params, /*genCall=*/true);
1047     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1048     // Replace operation with resulting memrefs.
1049     rewriter.replaceOpWithMultiple(op, {ret});
1050     return success();
1051   }
1052 };
1053 
1054 /// Sparse codegen rule for position accesses.
1055 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1056 public:
1057   using OpAdaptor = typename ToPositionsOp::Adaptor;
1058   using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1059   LogicalResult
1060   matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1061                   ConversionPatternRewriter &rewriter) const override {
1062     // Replace the requested position access with the corresponding field.
1063     // The view is restricted to the actual size to ensure clients
1064     // of this operation truly observe size, not capacity!
1065     Location loc = op.getLoc();
1066     Level lvl = op.getLevel();
1067     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1068                                              op.getTensor().getType());
1069     auto mem = desc.getPosMemRef(lvl);
1070     auto size = desc.getPosMemSize(rewriter, loc, lvl);
1071     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1072     return success();
1073   }
1074 };
1075 
1076 /// Sparse codegen rule for accessing the coordinates arrays.
1077 class SparseToCoordinatesConverter
1078     : public OpConversionPattern<ToCoordinatesOp> {
1079 public:
1080   using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1081   using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1082   LogicalResult
1083   matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1084                   ConversionPatternRewriter &rewriter) const override {
1085     // Replace the requested coordinates access with the corresponding field.
1086     // The view is restricted to the actual size to ensure clients
1087     // of this operation truly observe size, not capacity!
1088     Location loc = op.getLoc();
1089     Level lvl = op.getLevel();
1090     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1091                                              op.getTensor().getType());
1092     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1093     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1094       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1095       mem = genSliceToSize(rewriter, loc, mem, size);
1096     }
1097     rewriter.replaceOp(op, mem);
1098     return success();
1099   }
1100 };
1101 
1102 /// Sparse codegen rule for accessing the linear coordinates buffer.
1103 class SparseToCoordinatesBufferConverter
1104     : public OpConversionPattern<ToCoordinatesBufferOp> {
1105 public:
1106   using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1107   using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1108   LogicalResult
1109   matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1110                   ConversionPatternRewriter &rewriter) const override {
1111     // Replace the requested coordinates access with the corresponding field.
1112     // The view is restricted to the actual size to ensure clients
1113     // of this operation truly observe size, not capacity!
1114     Location loc = op.getLoc();
1115     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1116     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1117                                              op.getTensor().getType());
1118     auto mem = desc.getAOSMemRef();
1119     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1120     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1121     return success();
1122   }
1123 };
1124 
1125 /// Sparse codegen rule for value accesses.
1126 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1127 public:
1128   using OpAdaptor = typename ToValuesOp::Adaptor;
1129   using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1130   LogicalResult
1131   matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1132                   ConversionPatternRewriter &rewriter) const override {
1133     // Replace the requested values access with the corresponding field.
1134     // The view is restricted to the actual size to ensure clients
1135     // of this operation truly observe size, not capacity!
1136     Location loc = op.getLoc();
1137     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1138                                              op.getTensor().getType());
1139     auto mem = desc.getValMemRef();
1140     auto size = desc.getValMemSize(rewriter, loc);
1141     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1142     return success();
1143   }
1144 };
1145 
1146 /// Sparse codegen rule for the convert operator.
1147 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1148 public:
1149   using OpConversionPattern::OpConversionPattern;
1150   LogicalResult
1151   matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1152                   ConversionPatternRewriter &rewriter) const override {
1153     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1154     SparseTensorEncodingAttr encSrc =
1155         getSparseTensorEncoding(op.getSource().getType());
1156     // The output tensor cannot be a slice, and those cases should have been
1157     // rejected by ConvertOp::verify() already.
1158     assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1159     // Different encodings (except for different bitwidths) should be handled
1160     // by rewriting. We also need further rewrites if the input tensor is a
1161     // slice.
1162     if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1163         encSrc.isSlice()) {
1164       return failure();
1165     }
1166 
1167     Type retElemTp = op.getResult().getType().getElementType();
1168     Type srcElemTp = op.getSource().getType().getElementType();
1169     // Fold the trivial cases.
1170     if (retElemTp == srcElemTp && encDst == encSrc) {
1171       rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1172       return success();
1173     }
1174     //
1175     // Do element-wise type conversion without using InsertOp.
1176     //
1177     // for each memref in srcTensor:
1178     //   dst = memref.alloc
1179     //   if srcMemRefType != dstMemRefType:
1180     //     for every dst[i] = cast(src[i])
1181     //   else:
1182     //     dst = memref.copy(src)
1183     Location loc = op.getLoc();
1184     auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1185                                                 op.getSource().getType());
1186     SmallVector<Value> fields;
1187     foreachFieldAndTypeInSparseTensor(
1188         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1189         [&rewriter, &fields, srcDesc,
1190          loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1191               LevelType /*lt*/) -> bool {
1192           // Simply reuses the storage specifier as it is an SSA value.
1193           if (fKind == SparseTensorFieldKind::StorageSpec) {
1194             fields.push_back(srcDesc.getSpecifier());
1195           } else {
1196             // Allocates a new memref for this field.
1197             Value srcMem = srcDesc.getMemRefField(fIdx);
1198             // TODO: We could instead use the actual memSize in the specifier;
1199             // that would require a subViewOp to avoid overflow when copying
1200             // values.
1201             Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1202             auto dstMem = rewriter.create<memref::AllocOp>(
1203                 loc, cast<MemRefType>(fTp), sz);
1204             if (fTp != srcMem.getType()) {
1205               // Converts the element type.
1206               scf::buildLoopNest(
1207                   rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1208                   constantIndex(rewriter, loc, 1),
1209                   [srcMem, &dstMem](OpBuilder &builder, Location loc,
1210                                     ValueRange ivs) {
1211                     Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1212                     Value casted = genCast(builder, loc, v,
1213                                            dstMem.getType().getElementType());
1214                     builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1215                   });
1216             } else {
1217               // TODO: We could even reuse the same memref for the new tensor,
1218               // but that requires reference-counting based memory management
1219               // for memrefs shared between multiple sparse tensors.
1220               rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1221             }
1222             fields.push_back(dstMem);
1223           }
1224           return true;
1225         });
1226 
1227     rewriter.replaceOpWithMultiple(op, {fields});
1228     return success();
1229   }
1230 };
1231 
1232 class SparseExtractSliceConverter
1233     : public OpConversionPattern<tensor::ExtractSliceOp> {
1234 public:
1235   using OpConversionPattern::OpConversionPattern;
1236   LogicalResult
1237   matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1238                   ConversionPatternRewriter &rewriter) const override {
1239     Location loc = op.getLoc();
1240     MLIRContext *ctx = op.getContext();
1241     auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1242     auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1243     // TODO: We should check these in ExtractSliceOp::verify.
1244     if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1245       return failure();
1246     assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
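    // For example (a rough sketch, where #CSR_SLICE denotes the same encoding
    // as #CSR but with dimension slice information attached):
    //   %s = tensor.extract_slice %t[%o0, %o1] [%s0, %s1] [1, 1]
    //       : tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR_SLICE>
    // reuses every memref of %t verbatim and only creates a new storage
    // specifier recording the per-dimension offset/size/stride below.
    //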
1247 
1248     SmallVector<Value> fields;
1249     auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1250                                                 op.getSource().getType());
1251 
1252     auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1253         loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1254     desc.setSpecifier(newSpec);
1255 
1256     // Fills in slice information.
1257     for (auto [idx, offset, size, stride] : llvm::enumerate(
1258              op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1259       Dimension dim = idx;
1260 
1261       Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1262       Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1263       Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1264       // TODO: We could probably set only the dynamic values here, but that
1265       // would require us to fill in the hole when casting a static slice to
1266       // a dynamic slice.
1267       desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1268                              dim, offsetV);
1269 
1270       // FIXME: we need to distinguish level sizes and dimension sizes for
1271       // slices here. Maybe we should store the slice level sizes in a
1272       // separate array instead of reusing the existing one.
1273       assert(srcEnc.isIdentity());
1274       desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1275                              sizeV);
1276       desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1277                              dim, strideV);
1278     }
1279 
1280     // NOTE: we cannot generate the tuple directly from the descriptor here,
1281     // as the descriptor holds the original type, yet we want the slice type
1282     // here (they share every memref but carry an updated specifier).
1283     rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1284     return success();
1285   }
1286 };
1287 
1288 /// Sparse codegen rule for the number of entries operator.
1289 class SparseNumberOfEntriesConverter
1290     : public OpConversionPattern<NumberOfEntriesOp> {
1291 public:
1292   using OpConversionPattern::OpConversionPattern;
1293   LogicalResult
1294   matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1295                   ConversionPatternRewriter &rewriter) const override {
1296     // Query memSizes for the actually stored values.
1297     // FIXME: the nse value computed in this way might be wrong when there is
1298     // any "loose_compressed" level.
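    // E.g., %n = sparse_tensor.number_of_entries %t simply becomes a read of
    // the ValMemSize field of the storage specifier.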
1299     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1300                                              op.getTensor().getType());
1301     rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1302     return success();
1303   }
1304 };
1305 
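/// Sparse codegen rule for the assemble operator.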
1306 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1307   using OpConversionPattern::OpConversionPattern;
1308   LogicalResult
1309   matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1310                   ConversionPatternRewriter &rewriter) const override {
1311     Location loc = op.getLoc();
1312     const auto stt = getSparseTensorType(op.getResult());
1313 
1314     SmallVector<Value> fields;
1315 
1316     foreachFieldAndTypeInSparseTensor(
1317         stt,
1318         [&rewriter, &fields, &op, &stt,
1319          loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1320               Level /*lvl*/, LevelType lt) -> bool {
1321           assert(fields.size() == fIdx);
1322           if (fKind == SparseTensorFieldKind::StorageSpec) {
1323             fields.push_back(
1324                 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1325           } else {
1326             // Else simply takes the inputs.
1327             Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1328                                ? op.getValues()
1329                                : op.getLevels()[fIdx];
1330             // TODO: handle batch.
1331             TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1332             if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1333               // Flattens the buffer down to rank batchLvlRank + 1.
1334               auto reassoc = getReassociationForFlattening(
1335                   mem.getType(), stt.getBatchLvlRank());
1336               mem = rewriter.create<memref::CastOp>(
1337                   loc, fType,
1338                   rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1339             } else {
1340               mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1341             }
1342             fields.push_back(mem);
1343           }
1344           return true;
1345         });
1346 
1347     MutSparseTensorDescriptor desc(stt, fields);
1348     Value c0 = constantIndex(rewriter, loc, 0);
1349     Value c1 = constantIndex(rewriter, loc, 1);
1350     Value c2 = constantIndex(rewriter, loc, 2);
1351     Value posBack = c0; // index to the last value in the position array
1352     Value memSize = c1; // memory size for current array
1353 
1354     Level trailCOOStart = stt.getAoSCOOStart();
1355     Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1356     // Sets up SparseTensorSpecifier.
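    // For example (a rough sketch), for a 2-d CSR tensor with level sizes
    // d0 x d1, the loop below computes:
    //   lvl 0 (dense)     : memSize = d0, posBack = d0 - 1
    //   lvl 1 (compressed): posMemSize = d0 + 1, then
    //                       memSize = positions[d0] (the number of stored
    //                       entries) and crdMemSize = memSize,
    // and finally valMemSize = memSize.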
1357     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1358       assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1359 
1360       // Sets up the level size.
1361       auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1362       desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1363       // We use a single AOS array to store the trailing COO, so there is only
1364       // one memory size to set for the entire COO section.
1365       if (lvl > trailCOOStart)
1366         continue;
1367 
1368       // Sets up the memory size by reading the last value in the position array.
1369       LevelType lt = stt.getLvlType(lvl);
1370       // Simply forwards the position index when this is a dense level.
1371       if (lt.isa<LevelFormat::Dense>()) {
1372         memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
1373         posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1374         continue;
1375       }
1376       if (lt.isa<LevelFormat::Batch>()) {
1377         // Skips batch levels as they are not linearized.
1378         // FIXME: this assumes that every batch has the same nse; this needs
1379         // to be generalized to handle varied-size batches.
1380         continue;
1381       }
1382 
1383       if (isWithPosLT(lt)) {
1384         assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1385         if (isLooseCompressedLT(lt)) {
1386           memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
1387           posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1388         } else {
1389           assert(isCompressedLT(lt));
1390           posBack = memSize;
1391           memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
1392         }
1393         desc.setPosMemSize(rewriter, loc, lvl, memSize);
1394         // The last value in the position array is the memory size for the
1395         // next level. FIXME: this assumes that every batch has the same nse;
1396         // this needs to be generalized to handle varied-size batches.
1397         SmallVector<Value> batched(stt.getBatchLvlRank(),
1398                                    constantIndex(rewriter, loc, 0));
1399         batched.push_back(posBack);
1400         memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1401         posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
1402       }
1403       assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1404       // FIXME: This seems unnecessarily complex; can we simplify it?
1405       if (lvl == trailCOOStart) {
1406         Value cooSz = rewriter.create<arith::MulIOp>(
1407             loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1408         desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1409       } else {
1410         desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1411       }
1412     }
1413     desc.setValMemSize(rewriter, loc, memSize);
1414 
1415     rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1416     return success();
1417   }
1418 };
1419 
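/// Sparse codegen rule for the disassemble operator.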
1420 struct SparseDisassembleOpConverter
1421     : public OpConversionPattern<DisassembleOp> {
1422   using OpConversionPattern::OpConversionPattern;
1423   SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1424                                MLIRContext *context)
1425       : OpConversionPattern(typeConverter, context) {}
1426 
1427   LogicalResult
1428   matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1429                   ConversionPatternRewriter &rewriter) const override {
1430     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1431                                              op.getTensor().getType());
1432     Location loc = op.getLoc();
1433     SmallVector<Value> retMem;
1434     SmallVector<Value> retLen;
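    // Copies each stored field, restricted to its actual size, into the
    // corresponding user-supplied output buffer and records the used length.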
1435     desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1436                                    &retLen](FieldIndex fid,
1437                                             SparseTensorFieldKind fKind,
1438                                             Level lvl, LevelType lt) -> bool {
1439       if (fKind == SparseTensorFieldKind::StorageSpec)
1440         return true;
1441       SparseTensorType stt(desc.getRankedTensorType());
1442       Value sz, src;
1443       TypedValue<BaseMemRefType> dst;
1444       if (fKind == SparseTensorFieldKind::ValMemRef) {
1445         sz = desc.getValMemSize(rewriter, loc);
1446         src = desc.getValMemRef();
1447         dst = genToMemref(rewriter, loc, op.getOutValues());
1448 
1449         retMem.push_back(dst);
1450         Type valLenTp = op.getValLen().getType();
1451         retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1452       } else {
1453         assert(fKind == SparseTensorFieldKind::PosMemRef ||
1454                fKind == SparseTensorFieldKind::CrdMemRef);
1455 
1456         sz = fKind == SparseTensorFieldKind::PosMemRef
1457                  ? desc.getPosMemSize(rewriter, loc, lvl)
1458                  : desc.getCrdMemSize(rewriter, loc, lvl);
1459         src = desc.getMemRefField(fid);
1460         dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1461         retMem.push_back(dst);
1462         // Retrieves the corresponding level length type.
1463         Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1464         retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1465       }
1466       Value flatOut = dst;
1467       if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1468         auto reassoc =
1469             getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1470         flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1471       }
1472       Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1473       Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1474       rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1475       return true;
1476     });
1477 
1478     // Converts MemRefs back to Tensors.
1479     SmallVector<Value> retValues = llvm::to_vector(
1480         llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1481           return rewriter.create<bufferization::ToTensorOp>(loc, v);
1482         }));
1483     // Appends the actual memory length used by each returned buffer.
1484     retValues.append(retLen.begin(), retLen.end());
1485     rewriter.replaceOp(op, retValues);
1486     return success();
1487   }
1488 };
1489 
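/// Sparse codegen rule for the new operator.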
1490 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1491   using OpConversionPattern::OpConversionPattern;
1492   LogicalResult
1493   matchAndRewrite(NewOp op, OpAdaptor adaptor,
1494                   ConversionPatternRewriter &rewriter) const override {
1495     Location loc = op.getLoc();
1496     const auto dstTp = getSparseTensorType(op.getResult());
1497     // Creating COO with NewOp is handled by direct IR codegen. All other cases
1498     // are handled by rewriting.
1499     if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1500       return failure();
1501 
1502     // Implement as follows:
1503     //   %reader = @createCheckedSparseTensorReader(%filename)
1504     //   %nse = @getSparseTensorReaderNSE(%reader)
1505     //   %coo = bufferization.alloc_tensor an ordered COO with
1506     //          dst dim ordering, size_hint = %nse
1507     //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
1508     //   %values = sparse_tensor.values(%coo)
1509     //   %isSorted = @getSparseTensorReaderReadToBuffers(%coordinates, %values)
1510     //   if (!%isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1511     //   update storage specifier
1512     //   @delSparseTensorReader(%reader)
1513     SmallVector<Value> dimSizesValues;
1514     Value dimSizesBuffer;
1515     Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1516                              dimSizesValues, dimSizesBuffer);
1517 
1518     // Get the number of stored entries.
1519     const Type indexTp = rewriter.getIndexType();
1520     Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1521                                {indexTp}, {reader}, EmitCInterface::Off)
1522                     .getResult(0);
1523 
1524     // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1525     SmallVector<Value> lvlSizesValues;
1526     Value dim2lvlBuffer;
1527     Value lvl2dimBuffer;
1528     genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1529                   lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1530 
1531     // Construct allocation for each field.
1532     Value sizeHint = nse;
1533     SmallVector<Value> fields;
1534     createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1535                       lvlSizesValues, fields);
1536 
1537     // Read the COO tensor data.
1538     MutSparseTensorDescriptor desc(dstTp, fields);
1539     Value xs = desc.getAOSMemRef();
1540     Value ys = desc.getValMemRef();
1541     const Type boolTp = rewriter.getIntegerType(1);
1542     const Type elemTp = dstTp.getElementType();
1543     const Type crdTp = dstTp.getCrdType();
1544     SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1545                                           overheadTypeFunctionSuffix(crdTp),
1546                                           primaryTypeFunctionSuffix(elemTp)};
1547     Value isSorted =
1548         createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1549                        {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1550                        EmitCInterface::On)
1551             .getResult(0);
1552 
1553     // If the destination tensor is a sorted COO, we need to sort the COO tensor
1554     // data if the input elements aren't sorted yet.
1555     const Level lvlRank = dstTp.getLvlRank();
1556     if (dstTp.isOrderedLvl(lvlRank - 1)) {
1557       Value kFalse = constantI1(rewriter, loc, false);
1558       Value notSorted = rewriter.create<arith::CmpIOp>(
1559           loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1560       scf::IfOp ifOp =
1561           rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
1562       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1563       auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1564       rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1565                               rewriter.getIndexAttr(0),
1566                               SparseTensorSortKind::HybridQuickSort);
1567       rewriter.setInsertionPointAfter(ifOp);
1568     }
1569 
1570     // Set PosMemRef0[1] = nse.
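    // Level 0 of the trailing COO is compressed, so its position buffer holds
    // exactly two entries; entry 1 records the number of stored entries.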
1571     const Value c1 = constantIndex(rewriter, loc, 1);
1572     const Value posMemref0 = desc.getPosMemRef(0);
1573     const Type posTp = dstTp.getPosType();
1574     const Value posNse = genCast(rewriter, loc, nse, posTp);
1575     rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1576 
1577     // Update storage specifier.
1578     Value coordinatesSize = rewriter.create<arith::MulIOp>(
1579         loc, nse, constantIndex(rewriter, loc, lvlRank));
1580     desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1581                            coordinatesSize);
1582     desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1583                            std::nullopt, nse);
1584 
1585     // Release the sparse tensor reader.
1586     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1587                    EmitCInterface::Off);
1588 
1589     // Replace operation with resulting memrefs.
1590     rewriter.replaceOpWithMultiple(op, {fields});
1591     return success();
1592   }
1593 };
1594 
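/// Sparse codegen rule for the has_runtime_library operator. Under direct
/// codegen the runtime support library is not used, so the op folds to false.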
1595 struct SparseHasRuntimeLibraryConverter
1596     : public OpConversionPattern<HasRuntimeLibraryOp> {
1597   using OpConversionPattern::OpConversionPattern;
1598   LogicalResult
1599   matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1600                   ConversionPatternRewriter &rewriter) const override {
1601     auto i1Type = rewriter.getI1Type();
1602     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1603         op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1604     return success();
1605   }
1606 };
1607 
1608 } // namespace
1609 
1610 //===----------------------------------------------------------------------===//
1611 // Public method for populating conversion rules.
1612 //===----------------------------------------------------------------------===//
1613 
1614 /// Populates the given patterns list with conversion rules required to
1615 /// lower sparse tensor types and primitives to compiler-visible buffers.
1616 void mlir::populateSparseTensorCodegenPatterns(
1617     const TypeConverter &typeConverter, RewritePatternSet &patterns,
1618     bool createSparseDeallocs, bool enableBufferInitialization) {
1619   patterns.add<
1620       SparseAssembleOpConverter, SparseDisassembleOpConverter,
1621       SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1622       SparseCastConverter, SparseExtractSliceConverter,
1623       SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1624       SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1625       SparseSliceGetterOpConverter<ToSliceOffsetOp,
1626                                    StorageSpecifierKind::DimOffset>,
1627       SparseSliceGetterOpConverter<ToSliceStrideOp,
1628                                    StorageSpecifierKind::DimStride>,
1629       SparseToPositionsConverter, SparseToCoordinatesConverter,
1630       SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1631       SparseConvertConverter, SparseNewConverter,
1632       SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1633       typeConverter, patterns.getContext());
1634   patterns.add<SparseTensorDeallocConverter>(
1635       typeConverter, patterns.getContext(), createSparseDeallocs);
1636   patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1637       typeConverter, patterns.getContext(), enableBufferInitialization);
1638 }
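
// A rough usage sketch (hypothetical pass body, not the in-tree driver; the
// actual SparseTensorCodegen pass configures many more legality rules and
// additional patterns), assuming the SparseTensorTypeToBufferConverter type
// converter declared in Passes.h:
//
//   void runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     SparseTensorTypeToBufferConverter converter;
//     ConversionTarget target(*ctx);
//     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
//                            scf::SCFDialect>();
//     RewritePatternSet patterns(ctx);
//     populateSparseTensorCodegenPatterns(
//         converter, patterns, /*createSparseDeallocs=*/false,
//         /*enableBufferInitialization=*/false);
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }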
1639