xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (revision 9df63b2651b2435c02a7d825953ca2ddc65c778e)
186b22d31SAart Bik //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
286b22d31SAart Bik //
386b22d31SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
486b22d31SAart Bik // See https://llvm.org/LICENSE.txt for license information.
586b22d31SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
686b22d31SAart Bik //
786b22d31SAart Bik //===----------------------------------------------------------------------===//
886b22d31SAart Bik //
986b22d31SAart Bik // A pass that converts sparse tensor types and primitives to actual compiler
1086b22d31SAart Bik // visible buffers and actual compiler IR that implements these primitives on
1186b22d31SAart Bik // the selected sparse tensor storage schemes. This pass provides an alternative
1286b22d31SAart Bik // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13bc61122aSAart Bik // support library (other than for file I/O), and providing many more
14bc61122aSAart Bik // opportunities for subsequent compiler optimization of the generated code.
1586b22d31SAart Bik //
1686b22d31SAart Bik //===----------------------------------------------------------------------===//
1786b22d31SAart Bik 
18365777ecSAart Bik #include "Utils/CodegenUtils.h"
19365777ecSAart Bik #include "Utils/SparseTensorDescriptor.h"
2086b22d31SAart Bik 
216db397a8SPeiming Liu #include "mlir/Dialect/Arith/Utils/Utils.h"
222ddfacd9SAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2386b22d31SAart Bik #include "mlir/Dialect/Func/IR/FuncOps.h"
248a583bd5Sbixia1 #include "mlir/Dialect/Linalg/Utils/Utils.h"
2586b22d31SAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h"
26840e2ba3Sbixia1 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
2786b22d31SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
2986b22d31SAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
3086b22d31SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
3186b22d31SAart Bik #include "mlir/Transforms/DialectConversion.h"
32f708a549Swren romano 
33a1fe1f5fSKazu Hirata #include <optional>
3486b22d31SAart Bik 
3586b22d31SAart Bik using namespace mlir;
3686b22d31SAart Bik using namespace mlir::sparse_tensor;
3786b22d31SAart Bik 
3886b22d31SAart Bik //===----------------------------------------------------------------------===//
3986b22d31SAart Bik // Helper methods.
4086b22d31SAart Bik //===----------------------------------------------------------------------===//
4186b22d31SAart Bik 
42*9df63b26SMatthias Springer /// Flatten the given value ranges into a single vector of values.
43*9df63b26SMatthias Springer static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44*9df63b26SMatthias Springer   SmallVector<Value> result;
45*9df63b26SMatthias Springer   for (const auto &vals : values)
46*9df63b26SMatthias Springer     llvm::append_range(result, vals);
47*9df63b26SMatthias Springer   return result;
48edca72f5SPeiming Liu }
49*9df63b26SMatthias Springer 
50*9df63b26SMatthias Springer /// Assert that the given value range contains a single value and return it.
51*9df63b26SMatthias Springer static Value getSingleValue(ValueRange values) {
52*9df63b26SMatthias Springer   assert(values.size() == 1 && "expected single value");
53*9df63b26SMatthias Springer   return values.front();
545ab1a8aeSPeiming Liu }
55edca72f5SPeiming Liu 
5684cd51bbSwren romano /// Generates a load with proper `index` typing.
5770633a8dSAart Bik static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
5844ff23d5SPeiming Liu   idx = genCast(builder, loc, idx, builder.getIndexType());
5970633a8dSAart Bik   return builder.create<memref::LoadOp>(loc, mem, idx);
6070633a8dSAart Bik }
6170633a8dSAart Bik 
6284cd51bbSwren romano /// Generates a store with proper `index` typing and proper value.
6370633a8dSAart Bik static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
6470633a8dSAart Bik                      Value idx) {
6544ff23d5SPeiming Liu   idx = genCast(builder, loc, idx, builder.getIndexType());
6644ff23d5SPeiming Liu   val = genCast(builder, loc, val,
675550c821STres Popp                 cast<ShapedType>(mem.getType()).getElementType());
6870633a8dSAart Bik   builder.create<memref::StoreOp>(loc, val, mem, idx);
6970633a8dSAart Bik }
7070633a8dSAart Bik 
7170633a8dSAart Bik /// Creates a straightforward counting for-loop.
7270633a8dSAart Bik static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
73191c43f6SPeiming Liu                             MutableArrayRef<Value> fields,
7470633a8dSAart Bik                             Value lower = Value()) {
7570633a8dSAart Bik   Type indexType = builder.getIndexType();
7670633a8dSAart Bik   if (!lower)
7770633a8dSAart Bik     lower = constantZero(builder, loc, indexType);
7870633a8dSAart Bik   Value one = constantOne(builder, loc, indexType);
7970633a8dSAart Bik   scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
8070633a8dSAart Bik   for (unsigned i = 0, e = fields.size(); i < e; i++)
8170633a8dSAart Bik     fields[i] = forOp.getRegionIterArg(i);
8270633a8dSAart Bik   builder.setInsertionPointToStart(forOp.getBody());
8370633a8dSAart Bik   return forOp;
8470633a8dSAart Bik }
8570633a8dSAart Bik 
86bc61122aSAart Bik /// Creates a push back operation.
8770633a8dSAart Bik static void createPushback(OpBuilder &builder, Location loc,
88988733c6SPeiming Liu                            MutSparseTensorDescriptor desc,
89f708a549Swren romano                            SparseTensorFieldKind kind, std::optional<Level> lvl,
90f708a549Swren romano                            Value value, Value repeat = Value()) {
91f708a549Swren romano   Type etp = desc.getMemRefElementType(kind, lvl);
92f708a549Swren romano   Value field = desc.getMemRefField(kind, lvl);
93988733c6SPeiming Liu   StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
946607fdf7SAart Bik 
95988733c6SPeiming Liu   auto pushBackOp = builder.create<PushBackOp>(
96f708a549Swren romano       loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
9744ff23d5SPeiming Liu       genCast(builder, loc, value, etp), repeat);
98191c43f6SPeiming Liu 
99f708a549Swren romano   desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
100f708a549Swren romano   desc.setSpecifierField(builder, loc, specFieldKind, lvl,
101988733c6SPeiming Liu                          pushBackOp.getNewSize());
1023ae98fd2SAart Bik }
1033ae98fd2SAart Bik 
/// Generates code that allocates a sparse storage scheme for given rank,
/// starting at level `startLvl`. Walks dense levels (compounding their
/// sizes) until the first compressed/loose-compressed level, where it
/// pre-appends zero positions; singleton/n-out-of-m levels need nothing.
/// If all remaining levels are dense, appends zeros to the values array.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
                               MutSparseTensorDescriptor desc, Level startLvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  // Running product of the dense level sizes seen so far; determines how
  // many entries must be appended at the first non-dense level (or in the
  // values array for the all-dense tail).
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times. For loose
      // compression, we multiply linear by two in order to append both the
      // lo/hi positions.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      if (isLooseCompressedLT(lt)) {
        Value two = constantIndex(builder, loc, 2);
        linear = builder.create<arith::MulIOp>(loc, linear, two);
      }
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
      return;
    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
      return; // nothing to do
    }
    // Keep compounding the size, but nothing needs to be initialized
    // at this level. We will eventually reach a compressed level or
    // otherwise the values array for the from-here "all-dense" case.
    assert(isDenseLT(lt));
    Value size = desc.getLvlSize(builder, loc, lvl);
    linear = builder.create<arith::MulIOp>(loc, linear, size);
  }
  // Reached values array so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
}
14170633a8dSAart Bik 
14280b08b68SAart Bik /// Creates allocation operation.
143191c43f6SPeiming Liu static Value createAllocation(OpBuilder &builder, Location loc,
144191c43f6SPeiming Liu                               MemRefType memRefType, Value sz,
145191c43f6SPeiming Liu                               bool enableInit) {
146191c43f6SPeiming Liu   Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
147191c43f6SPeiming Liu   Type elemType = memRefType.getElementType();
1487276b643Sbixia1   if (enableInit) {
149ea4be70cSbixia1     Value fillValue = constantZero(builder, loc, elemType);
1507276b643Sbixia1     builder.create<linalg::FillOp>(loc, fillValue, buffer);
1517276b643Sbixia1   }
1527276b643Sbixia1   return buffer;
1530c7abd39SAart Bik }
1540c7abd39SAart Bik 
15583cf0dc9SAart Bik /// Creates the dim sizes array, filling in from dynamic sizes.
15683cf0dc9SAart Bik static void createDimSizes(OpBuilder &builder, Location loc,
15783cf0dc9SAart Bik                            SparseTensorType stt, ValueRange dynSizes,
15883cf0dc9SAart Bik                            /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
15983cf0dc9SAart Bik   const Dimension dimRank = stt.getDimRank();
16083cf0dc9SAart Bik   dimSizesValues.clear();
16183cf0dc9SAart Bik   dimSizesValues.reserve(dimRank);
16283cf0dc9SAart Bik   unsigned i = 0;
16383cf0dc9SAart Bik   for (const Size sz : stt.getDimShape())
16483cf0dc9SAart Bik     dimSizesValues.push_back(ShapedType::isDynamic(sz)
16583cf0dc9SAart Bik                                  ? dynSizes[i++]
16683cf0dc9SAart Bik                                  : constantIndex(builder, loc, sz));
16783cf0dc9SAart Bik }
16883cf0dc9SAart Bik 
/// Creates allocation for each field in sparse tensor type. Note that
/// for all dynamic memrefs in the sparse tensor storage layout, the
/// memory size is really the capacity of the "vector", while the actual
/// size resides in the sizes array.
static void createAllocFields(OpBuilder &builder, Location loc,
                              SparseTensorType stt, bool enableInit,
                              Value sizeHint,
                              SmallVectorImpl<Value> &lvlSizesValues,
                              /*out*/ SmallVectorImpl<Value> &fields) {
  Level lvlRank = stt.getLvlRank();
  // Set up some heuristic sizes. We try to set the initial
  // size based on available information. Otherwise we just
  // initialize a few elements to start the reallocation chain.
  // TODO: refine this
  Value posHeuristic, crdHeuristic, valHeuristic;
  if (stt.isAllDense()) {
    // All-dense: values need room for the product of all level sizes.
    valHeuristic = lvlSizesValues[0];
    for (Level lvl = 1; lvl < lvlRank; lvl++)
      valHeuristic =
          builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
  } else if (sizeHint) {
    if (stt.getAoSCOOStart() == 0) {
      // AoS COO from level 0: two positions, and lvlRank coordinates
      // per expected entry.
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      // CSR-like: sizeHint + 1 positions, sizeHint coordinates.
      posHeuristic = builder.create<arith::AddIOp>(
          loc, sizeHint, constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
  }
  // Initializes all fields. An initial storage specifier and allocated
  // positions/coordinates/values memrefs (with heuristic capacity).
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
       enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                   Level /*lvl*/, LevelType /*lt*/) -> bool {
        // Fields are visited in layout order; each callback appends one.
        assert(fields.size() == fIdx);
        Value field;
        switch (fKind) {
        case SparseTensorFieldKind::StorageSpec:
          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
          break;
        case SparseTensorFieldKind::PosMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   posHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::CrdMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   crdHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::ValMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   valHeuristic, enableInit);
          break;
        }
        assert(field);
        fields.push_back(field);
        // Returns true to continue the iteration.
        return true;
      });
  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
  // and gives all position fields an initial zero entry, so that it is
  // easier to maintain the "linear + 1" length property.
  MutSparseTensorDescriptor desc(stt, fields);
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}
2510c7abd39SAart Bik 
/// Helper method that generates block specific to compressed case:
///
///  // given: parentPos = posCursor[lvl-1]
///  pstart = desc.positions[lvl][parentPos]
///  pstop = desc.positions[lvl][parentPos+1]
///  plast = pstop - 1
///  msz = desc.coordinates[lvl].size()
///  if (pstart < pstop) {
///    isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
///  } else { // first insertion
///    isPresent = false
///    desc.positions[lvl][parentPos] = msz
///  }
///  if (isPresent) { // coordinate is already present
///    pnext = plast
///  } else {
///    desc.coordinates[lvl].push_back(lvlCoords[lvl])
///    desc.positions[lvl][parentPos+1] = msz+1
///    pnext = msz
///    <prepare level lvl+1>
///  }
///  posCursor[lvl] = pnext
static Value genCompressed(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc, ValueRange lvlCoords,
                           Value /*unused*/, Value parentPos, Level lvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  SmallVector<Type> types;
  Type indexType = builder.getIndexType();
  Type boolType = builder.getIntegerType(1);
  // A coordinate memref may be strided (e.g. for AoS COO storage); the
  // stride is compensated for in all loads and in the size computation.
  unsigned crdFidx;
  unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  // Logical number of stored coordinates = raw memory size / stride.
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
  types.pop_back();
  // Then-branch: compare the last stored coordinate against the one
  // being inserted to determine presence.
  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  Value crd =
      genLoad(builder, loc, desc.getMemRefField(crdFidx),
              crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                         : plast);
  Value eq = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
      lvlCoords[lvl]);
  builder.create<scf::YieldOp>(loc, eq);
  // Else-branch: first insertion at this parent; record start position.
  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  if (lvl > 0)
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
  builder.setInsertionPointAfter(ifOp1);
  // If present construct. Note that for a non-unique dimension level, we
  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
  //
  // TODO: generate less temporary IR?
  //
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
  // If present (fields unaffected, update pnext to plast).
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());

  // FIXME: This does not looks like a clean way, but probably the most
  // efficient way.
  desc.getFields().push_back(plast);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // If !present (changes fields, update pnext).
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 /*value=*/lvlCoords[lvl]);
  // Prepare the next level "as needed".
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);

  desc.getFields().push_back(msz);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // Update fields and return next pos.
  builder.setInsertionPointAfter(ifOp2);
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}
3623986c869SAart Bik 
/// Generates insertion finalization code: for every compressed level past
/// the first, propagates the last-seen position into any entries left at
/// zero by the insertion pass, restoring a consistent positions array.
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt)) {
      // Compressed dimensions need a position cleanup for all entries
      // that were not visited during the insertion pass.
      //
      // TODO: avoid cleanup and keep compressed scheme consistent at all
      // times?
      //
      if (lvl > 0) {
        Type posType = stt.getPosType();
        Value posMemRef = desc.getPosMemRef(lvl);
        Value hi = desc.getPosMemSize(builder, loc, lvl);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
        // Scan positions[1..hi): a zero entry was never visited and is
        // overwritten with the previous (carried) position value.
        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
        Value i = loop.getInductionVar();
        Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
        Value cond = builder.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, newv, posZero);
        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
                                                   cond, /*else*/ true);
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
        builder.create<scf::YieldOp>(loc, oldv);
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<scf::YieldOp>(loc, newv);
        builder.setInsertionPointAfter(ifOp);
        // Carry the position value seen so far to the next iteration.
        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
        builder.setInsertionPointAfter(loop);
      }
    } else {
      assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
             isNOutOfMLT(lt));
    }
  }
}
409bc61122aSAart Bik 
410bc61122aSAart Bik /// Generates a subview into the sizes.
411bc61122aSAart Bik static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
412bc61122aSAart Bik                             Value sz) {
413c4e5a8a4SAart Bik   auto memTp = llvm::cast<MemRefType>(mem.getType());
414c4e5a8a4SAart Bik   // For higher-dimensional memrefs, we assume that the innermost
415c4e5a8a4SAart Bik   // dimension is always of the right size.
416c4e5a8a4SAart Bik   // TODO: generate complex truncating view here too?
417c4e5a8a4SAart Bik   if (memTp.getRank() > 1)
418c4e5a8a4SAart Bik     return mem;
419c4e5a8a4SAart Bik   // Truncate linear memrefs to given size.
420bc61122aSAart Bik   return builder
421bc61122aSAart Bik       .create<memref::SubViewOp>(
422c4e5a8a4SAart Bik           loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
423c4e5a8a4SAart Bik           mem, ValueRange{}, ValueRange{sz}, ValueRange{},
424bc61122aSAart Bik           ArrayRef<int64_t>{0},                    // static offset
425bc61122aSAart Bik           ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
426bc61122aSAart Bik           ArrayRef<int64_t>{1})                    // static stride
427bc61122aSAart Bik       .getResult();
428bc61122aSAart Bik }
429bc61122aSAart Bik 
430bc61122aSAart Bik /// Creates the reassociation array.
43152b69aa3SPeiming Liu static SmallVector<ReassociationIndices>
43252b69aa3SPeiming Liu getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
43352b69aa3SPeiming Liu   SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
43452b69aa3SPeiming Liu   // Create reassociation in the form:
43552b69aa3SPeiming Liu   // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
43652b69aa3SPeiming Liu   for (unsigned i = 0; i < batchLvls; i++)
43752b69aa3SPeiming Liu     ret[i].push_back(i);
43852b69aa3SPeiming Liu 
43952b69aa3SPeiming Liu   for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
44052b69aa3SPeiming Liu     ret.back().push_back(i);
44152b69aa3SPeiming Liu 
44252b69aa3SPeiming Liu   return ret;
443bc61122aSAart Bik }
444bc61122aSAart Bik 
445bc61122aSAart Bik //===----------------------------------------------------------------------===//
446bc61122aSAart Bik // Codegen rules.
447bc61122aSAart Bik //===----------------------------------------------------------------------===//
448bc61122aSAart Bik 
449bc61122aSAart Bik namespace {
450bc61122aSAart Bik 
451ad469385SPeiming Liu /// Helper class to help lowering sparse_tensor.insert operation.
452ad469385SPeiming Liu class SparseInsertGenerator
453ad469385SPeiming Liu     : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
454ad469385SPeiming Liu public:
455ad469385SPeiming Liu   SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
456ad469385SPeiming Liu                         bool genCall)
457ad469385SPeiming Liu       : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
458ad469385SPeiming Liu 
45970633a8dSAart Bik   /// Generates code along an insertion path without the need for a "cursor".
46070633a8dSAart Bik   /// This current insertion strategy comes at the expense of some testing
46170633a8dSAart Bik   /// overhead for each insertion. The strategy will be optimized later for
46270633a8dSAart Bik   /// common insertion patterns. The current insertion strategy also assumes
46370633a8dSAart Bik   /// insertions occur in "a reasonable order" that enables building the
46470633a8dSAart Bik   /// storage scheme in an appending/inserting kind of fashion (i.e. no
46570633a8dSAart Bik   /// in-between insertions that need data movement). The implementation
46670633a8dSAart Bik   /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
46770633a8dSAart Bik   ///
46870633a8dSAart Bik   /// TODO: better unord/not-unique; also generalize, optimize, specialize!
469ad469385SPeiming Liu   SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
470ad469385SPeiming Liu                                        OpBuilder &builder, Location loc) {
47168f58812STres Popp     const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
472f708a549Swren romano     const Level lvlRank = stt.getLvlRank();
47384cd51bbSwren romano     // Extract fields and coordinates from args.
474f708a549Swren romano     SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
475ad469385SPeiming Liu     MutSparseTensorDescriptor desc(stt, fields);
476962484aeSwren romano     const SmallVector<Value> coords =
477f708a549Swren romano         llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
4782aceaddaSbixia1     Value value = args.back();
47984cd51bbSwren romano     Value parentPos = constantZero(builder, loc, builder.getIndexType());
480f708a549Swren romano     // Generate code for every level.
481160d483bSAart Bik     for (Level lvl = 0; lvl < lvlRank; lvl++) {
4821dd387e1SAart Bik       const auto lt = stt.getLvlType(lvl);
4831dd387e1SAart Bik       if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
48470633a8dSAart Bik         // Create:
48570633a8dSAart Bik         //   if (!present) {
486160d483bSAart Bik         //     coordinates[lvl].push_back(coords[lvl])
487160d483bSAart Bik         //     <update positions and prepare level lvl + 1>
48870633a8dSAart Bik         //   }
489160d483bSAart Bik         //   positions[lvl] = coordinates.size() - 1
490160d483bSAart Bik         //   <insert @ positions[lvl] at next level lvl + 1>
4911dd387e1SAart Bik         if (isLooseCompressedLT(lt)) {
492160d483bSAart Bik           Value two = constantIndex(builder, loc, 2);
493160d483bSAart Bik           parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
494160d483bSAart Bik         }
49584cd51bbSwren romano         parentPos =
496160d483bSAart Bik             genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
497e5924d64SYinying Li       } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
49870633a8dSAart Bik         // Create:
499160d483bSAart Bik         //   coordinates[lvl].push_back(coords[lvl])
500160d483bSAart Bik         //   positions[lvl] = positions[lvl-1]
501160d483bSAart Bik         //   <insert @ positions[lvl] at next level lvl + 1>
502160d483bSAart Bik         createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
503160d483bSAart Bik                        lvl, /*value=*/coords[lvl]);
50470633a8dSAart Bik       } else {
5051dd387e1SAart Bik         assert(isDenseLT(lt));
50670633a8dSAart Bik         // Construct the new position as:
507160d483bSAart Bik         //   positions[lvl] = size * positions[lvl-1] + coords[lvl]
508160d483bSAart Bik         //   <insert @ positions[lvl] at next level lvl + 1>
509160d483bSAart Bik         Value size = desc.getLvlSize(builder, loc, lvl);
51084cd51bbSwren romano         Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
511160d483bSAart Bik         parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
51270633a8dSAart Bik       }
51370633a8dSAart Bik     }
51470633a8dSAart Bik     // Reached the actual value append/insert.
515f708a549Swren romano     if (!stt.isDenseLvl(lvlRank - 1))
516988733c6SPeiming Liu       createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
517988733c6SPeiming Liu                      std::nullopt, value);
51870633a8dSAart Bik     else
51984cd51bbSwren romano       genStore(builder, loc, value, desc.getValMemRef(), parentPos);
520ad469385SPeiming Liu     return fields;
5212aceaddaSbixia1   }
5222aceaddaSbixia1 
523ad469385SPeiming Liu   std::string getMangledFuncName() {
5242aceaddaSbixia1     // The mangled name of the function has this format:
5251dd387e1SAart Bik     //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
526ad469385SPeiming Liu     constexpr const char kInsertFuncNamePrefix[] = "_insert_";
52768f58812STres Popp     const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
5282aceaddaSbixia1     SmallString<32> nameBuffer;
5292aceaddaSbixia1     llvm::raw_svector_ostream nameOstream(nameBuffer);
530ad469385SPeiming Liu     nameOstream << kInsertFuncNamePrefix;
531f708a549Swren romano     const Level lvlRank = stt.getLvlRank();
5326280e231SYinying Li     for (Level l = 0; l < lvlRank; l++) {
5336280e231SYinying Li       std::string lvlType = toMLIRString(stt.getLvlType(l));
5346280e231SYinying Li       // Replace/remove punctuations in level properties.
5356280e231SYinying Li       std::replace_if(
5366280e231SYinying Li           lvlType.begin(), lvlType.end(),
5376280e231SYinying Li           [](char c) { return c == '(' || c == ','; }, '_');
5386461a824SKazu Hirata       llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
5396280e231SYinying Li       nameOstream << lvlType << "_";
5406280e231SYinying Li     }
5412aceaddaSbixia1     // Static dim sizes are used in the generated code while dynamic sizes are
5422aceaddaSbixia1     // loaded from the dimSizes buffer. This is the reason for adding the shape
5432aceaddaSbixia1     // to the function name.
544160d483bSAart Bik     for (const auto sz : stt.getDimShape())
545160d483bSAart Bik       nameOstream << sz << "_";
5462aceaddaSbixia1     // Permutation information is also used in generating insertion.
547f708a549Swren romano     if (!stt.isIdentity())
54876647fceSwren romano       nameOstream << stt.getDimToLvl() << "_";
549f708a549Swren romano     nameOstream << stt.getElementType() << "_";
55084cd51bbSwren romano     nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
551ad469385SPeiming Liu     return nameOstream.str().str();
5522aceaddaSbixia1   }
5532aceaddaSbixia1 
554ad469385SPeiming Liu private:
555ad469385SPeiming Liu   TensorType rtp;
556ad469385SPeiming Liu };
5579f596a7cSAart Bik 
558edca72f5SPeiming Liu /// Sparse tensor storage conversion rule for returns.
55986b22d31SAart Bik class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
56086b22d31SAart Bik public:
56186b22d31SAart Bik   using OpConversionPattern::OpConversionPattern;
56286b22d31SAart Bik   LogicalResult
563*9df63b26SMatthias Springer   matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
56486b22d31SAart Bik                   ConversionPatternRewriter &rewriter) const override {
565edca72f5SPeiming Liu     // Create a return with the flattened value extracted from sparse tensors.
566*9df63b26SMatthias Springer     rewriter.replaceOpWithNewOp<func::ReturnOp>(
567*9df63b26SMatthias Springer         op, flattenValues(adaptor.getOperands()));
568edca72f5SPeiming Liu     return success();
569edca72f5SPeiming Liu   }
570edca72f5SPeiming Liu };
571edca72f5SPeiming Liu 
572edca72f5SPeiming Liu /// Sparse tensor storage conversion rule for calls.
573edca72f5SPeiming Liu class SparseCallConverter : public OpConversionPattern<func::CallOp> {
574edca72f5SPeiming Liu public:
575edca72f5SPeiming Liu   // The default CallOp converter can not handle 1:N type conversion.
576edca72f5SPeiming Liu   using OpConversionPattern::OpConversionPattern;
577edca72f5SPeiming Liu   LogicalResult
578*9df63b26SMatthias Springer   matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
579edca72f5SPeiming Liu                   ConversionPatternRewriter &rewriter) const override {
580edca72f5SPeiming Liu     Location loc = op.getLoc();
581edca72f5SPeiming Liu     // In case of:
582edca72f5SPeiming Liu     //  sparse_tensor, f, sparse_tensor = call @foo(...)
583edca72f5SPeiming Liu     // ==>
584edca72f5SPeiming Liu     //  memref..., f, memref = call @foo(...) replace with
585edca72f5SPeiming Liu     //  cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
5860e1708ffSAart Bik     SmallVector<Type> finalRetTy;
587edca72f5SPeiming Liu     if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
588edca72f5SPeiming Liu       return failure();
589edca72f5SPeiming Liu 
590be556ee1SYinying Li     // (1) Generates new call with flattened return value.
591*9df63b26SMatthias Springer     auto newCall = rewriter.create<func::CallOp>(
592*9df63b26SMatthias Springer         loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
593aed43562SMatthias Springer     // (2) Gather sparse tensor returns.
594aed43562SMatthias Springer     SmallVector<SmallVector<Value>> packedResultVals;
595be556ee1SYinying Li     // Tracks the offset of current return value (of the original call)
596edca72f5SPeiming Liu     // relative to the new call (after sparse tensor flattening);
597edca72f5SPeiming Liu     unsigned retOffset = 0;
598edca72f5SPeiming Liu     // Temporal buffer to hold the flattened list of type for
599edca72f5SPeiming Liu     // a sparse tensor.
6000e1708ffSAart Bik     SmallVector<Type> sparseFlat;
601edca72f5SPeiming Liu     for (auto ret : op.getResults()) {
602edca72f5SPeiming Liu       assert(retOffset < newCall.getNumResults());
603edca72f5SPeiming Liu       auto retType = ret.getType();
604edca72f5SPeiming Liu       if (failed(typeConverter->convertType(retType, sparseFlat)))
605edca72f5SPeiming Liu         llvm_unreachable("Failed to convert type in sparse tensor codegen");
606edca72f5SPeiming Liu 
607edca72f5SPeiming Liu       // Converted types can not be empty when the type conversion succeed.
608edca72f5SPeiming Liu       assert(!sparseFlat.empty());
609edca72f5SPeiming Liu       if (sparseFlat.size() > 1) {
610edca72f5SPeiming Liu         auto flatSize = sparseFlat.size();
611aed43562SMatthias Springer         packedResultVals.emplace_back();
612aed43562SMatthias Springer         llvm::append_range(packedResultVals.back(),
613aed43562SMatthias Springer                            newCall.getResults().slice(retOffset, flatSize));
614edca72f5SPeiming Liu         retOffset += flatSize;
615edca72f5SPeiming Liu       } else {
616edca72f5SPeiming Liu         // If this is an 1:1 conversion, no need for casting.
617aed43562SMatthias Springer         packedResultVals.emplace_back();
618aed43562SMatthias Springer         packedResultVals.back().push_back(newCall.getResult(retOffset));
619edca72f5SPeiming Liu         retOffset++;
620edca72f5SPeiming Liu       }
621edca72f5SPeiming Liu       sparseFlat.clear();
622edca72f5SPeiming Liu     }
623edca72f5SPeiming Liu 
624aed43562SMatthias Springer     assert(packedResultVals.size() == op.getNumResults());
625aed43562SMatthias Springer     rewriter.replaceOpWithMultiple(
626aed43562SMatthias Springer         op, llvm::to_vector_of<ValueRange>(packedResultVals));
62786b22d31SAart Bik     return success();
62886b22d31SAart Bik   }
62986b22d31SAart Bik };
63086b22d31SAart Bik 
631c780352dSPeiming Liu /// Sparse codegen rule for level accesses.
632c780352dSPeiming Liu class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
6331be09496SAart Bik public:
6341be09496SAart Bik   using OpConversionPattern::OpConversionPattern;
6351be09496SAart Bik   LogicalResult
636*9df63b26SMatthias Springer   matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
6371be09496SAart Bik                   ConversionPatternRewriter &rewriter) const override {
638c780352dSPeiming Liu     std::optional<int64_t> lvl = op.getConstantLvlIndex();
639204234a6SMatthias Springer     RankedTensorType srcType = op.getSource().getType();
640204234a6SMatthias Springer     if (!lvl || !getSparseTensorEncoding(srcType))
6411be09496SAart Bik       return failure();
642191c43f6SPeiming Liu 
643204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
644c780352dSPeiming Liu     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
645191c43f6SPeiming Liu 
64683a50839SPeiming Liu     rewriter.replaceOp(op, sz);
6473ae98fd2SAart Bik     return success();
6483ae98fd2SAart Bik   }
6493ae98fd2SAart Bik };
6503ae98fd2SAart Bik 
651dda3dc5eSPeiming Liu // TODO: use a new SortCOO operation here instead of reusing convert op.
652f248d0b2SPeiming Liu struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
653dda3dc5eSPeiming Liu   using OpConversionPattern::OpConversionPattern;
654dda3dc5eSPeiming Liu   LogicalResult
655*9df63b26SMatthias Springer   matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
656dda3dc5eSPeiming Liu                   ConversionPatternRewriter &rewriter) const override {
657dda3dc5eSPeiming Liu     Location loc = op.getLoc();
658dda3dc5eSPeiming Liu     MLIRContext *ctx = op.getContext();
659dda3dc5eSPeiming Liu 
660f248d0b2SPeiming Liu     SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
661f248d0b2SPeiming Liu     SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
662dda3dc5eSPeiming Liu 
663f248d0b2SPeiming Liu     // Should have been verified.
664dda3dc5eSPeiming Liu     assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
6655b729503SAart Bik            dstStt.isCOOType() && srcStt.isCOOType());
666dda3dc5eSPeiming Liu     assert(dstStt.hasSameDimToLvl(srcStt));
667dda3dc5eSPeiming Liu 
668dda3dc5eSPeiming Liu     // We don't need a mutable descriptor here as we perform sorting in-place.
669204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
670204234a6SMatthias Springer                                              op.getInputCoo().getType());
671204234a6SMatthias Springer     auto nnz = desc.getValMemSize(rewriter, op.getLoc());
672dda3dc5eSPeiming Liu     auto crd = desc.getAOSMemRef();
673dda3dc5eSPeiming Liu     auto val = desc.getValMemRef();
674dda3dc5eSPeiming Liu 
675dda3dc5eSPeiming Liu     // Otherwise we need another data shuffle and a non-identity map.
676dda3dc5eSPeiming Liu     assert(dstStt.hasSameDimToLvl(srcStt));
677837a26f2SPeiming Liu     (void)dstStt; // to silence warning when assertion is disabled
678837a26f2SPeiming Liu 
679dda3dc5eSPeiming Liu     auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
680dda3dc5eSPeiming Liu 
681dda3dc5eSPeiming Liu     rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
682f248d0b2SPeiming Liu                             rewriter.getIndexAttr(0), op.getAlgorithm());
683dda3dc5eSPeiming Liu 
684dda3dc5eSPeiming Liu     // Since we do in-place sorting, the destinate tensor will have the same set
685dda3dc5eSPeiming Liu     // of memrefs as the source tensor.
686*9df63b26SMatthias Springer     rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
687dda3dc5eSPeiming Liu     return success();
688dda3dc5eSPeiming Liu   }
689dda3dc5eSPeiming Liu };
690dda3dc5eSPeiming Liu 
6916db397a8SPeiming Liu template <typename Op, StorageSpecifierKind kind>
6926db397a8SPeiming Liu class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
6936db397a8SPeiming Liu public:
6946db397a8SPeiming Liu   using OpConversionPattern<Op>::OpConversionPattern;
695*9df63b26SMatthias Springer   using typename OpConversionPattern<Op>::OneToNOpAdaptor;
696*9df63b26SMatthias Springer 
6976db397a8SPeiming Liu   LogicalResult
698*9df63b26SMatthias Springer   matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
6996db397a8SPeiming Liu                   ConversionPatternRewriter &rewriter) const override {
7006db397a8SPeiming Liu     // Simply lowers to specifer.get <field> operation.
701204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
702204234a6SMatthias Springer                                              op.getSlice().getType());
7036db397a8SPeiming Liu     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
7046db397a8SPeiming Liu                                     op.getDim().getZExtValue());
7056db397a8SPeiming Liu 
7066db397a8SPeiming Liu     rewriter.replaceOp(op, v);
7076db397a8SPeiming Liu     return success();
7086db397a8SPeiming Liu   }
7096db397a8SPeiming Liu };
7106db397a8SPeiming Liu 
711f27b806dSAart Bik /// Sparse codegen rule for trivial tensor casts.
712f27b806dSAart Bik class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
713f27b806dSAart Bik public:
714f27b806dSAart Bik   using OpConversionPattern::OpConversionPattern;
715f27b806dSAart Bik   LogicalResult
716*9df63b26SMatthias Springer   matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
717f27b806dSAart Bik                   ConversionPatternRewriter &rewriter) const override {
718f27b806dSAart Bik     // Only rewrite identically annotated source/dest.
719f27b806dSAart Bik     auto encDst = getSparseTensorEncoding(op.getType());
720f27b806dSAart Bik     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
721f27b806dSAart Bik     if (!encDst || encDst != encSrc)
722f27b806dSAart Bik       return failure();
723*9df63b26SMatthias Springer     rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
724f27b806dSAart Bik     return success();
725f27b806dSAart Bik   }
726f27b806dSAart Bik };
727f27b806dSAart Bik 
728ef222988SPeiming Liu class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
729ef222988SPeiming Liu public:
730ef222988SPeiming Liu   using OpConversionPattern::OpConversionPattern;
731ef222988SPeiming Liu   LogicalResult
732*9df63b26SMatthias Springer   matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
733ef222988SPeiming Liu                   ConversionPatternRewriter &rewriter) const override {
734ef222988SPeiming Liu     // Simply fold the operation.
735*9df63b26SMatthias Springer     rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
736ef222988SPeiming Liu     return success();
737ef222988SPeiming Liu   }
738ef222988SPeiming Liu };
739ef222988SPeiming Liu 
740be556ee1SYinying Li /// Sparse codegen rule for the alloc operator.
7410c7abd39SAart Bik class SparseTensorAllocConverter
7420c7abd39SAart Bik     : public OpConversionPattern<bufferization::AllocTensorOp> {
7430c7abd39SAart Bik public:
7440c7abd39SAart Bik   using OpConversionPattern::OpConversionPattern;
745206fad0eSMatthias Springer   SparseTensorAllocConverter(const TypeConverter &typeConverter,
746206fad0eSMatthias Springer                              MLIRContext *context, bool enableInit)
7477276b643Sbixia1       : OpConversionPattern(typeConverter, context),
7487276b643Sbixia1         enableBufferInitialization(enableInit) {}
749988733c6SPeiming Liu 
7500c7abd39SAart Bik   LogicalResult
751*9df63b26SMatthias Springer   matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
7520c7abd39SAart Bik                   ConversionPatternRewriter &rewriter) const override {
753f708a549Swren romano     const auto resType = getSparseTensorType(op);
754f708a549Swren romano     if (!resType.hasEncoding())
7550c7abd39SAart Bik       return failure();
75683cf0dc9SAart Bik 
7572cc4b3d0SPeiming Liu     Location loc = op.getLoc();
75883cf0dc9SAart Bik     // Deal with copy.
7597b86f7c5SPeiming Liu     if (op.getCopy()) {
760204234a6SMatthias Springer       auto desc = getDescriptorFromTensorTuple(
761204234a6SMatthias Springer           adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
7627b86f7c5SPeiming Liu       SmallVector<Value> fields;
7637b86f7c5SPeiming Liu       fields.reserve(desc.getNumFields());
7647b86f7c5SPeiming Liu       // Memcpy on memref fields.
7657b86f7c5SPeiming Liu       for (auto field : desc.getMemRefFields()) {
7665550c821STres Popp         auto memrefTp = cast<MemRefType>(field.getType());
7677b86f7c5SPeiming Liu         auto size = rewriter.create<memref::DimOp>(loc, field, 0);
7687b86f7c5SPeiming Liu         auto copied =
7697b86f7c5SPeiming Liu             rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
7707b86f7c5SPeiming Liu         rewriter.create<memref::CopyOp>(loc, field, copied);
7717b86f7c5SPeiming Liu         fields.push_back(copied);
7727b86f7c5SPeiming Liu       }
7737b86f7c5SPeiming Liu       // Reuses specifier.
7747b86f7c5SPeiming Liu       fields.push_back(desc.getSpecifier());
7757b86f7c5SPeiming Liu       assert(fields.size() == desc.getNumFields());
776aed43562SMatthias Springer       rewriter.replaceOpWithMultiple(op, {fields});
7777b86f7c5SPeiming Liu       return success();
7787b86f7c5SPeiming Liu     }
7797b86f7c5SPeiming Liu 
7802cc4b3d0SPeiming Liu     if (!resType.isIdentity()) {
7812cc4b3d0SPeiming Liu       return rewriter.notifyMatchFailure(
7822cc4b3d0SPeiming Liu           op, "try run --sparse-reinterpret-map before codegen");
7832cc4b3d0SPeiming Liu     }
7842cc4b3d0SPeiming Liu     // Level size equals to dimension size since lvl2dim map is an identity map.
78583cf0dc9SAart Bik     SmallVector<Value> lvlSizesValues;
786*9df63b26SMatthias Springer     createDimSizes(rewriter, loc, resType,
787*9df63b26SMatthias Springer                    flattenValues(adaptor.getDynamicSizes()),
7882cc4b3d0SPeiming Liu                    /*dimSizesValues=*/lvlSizesValues);
78983cf0dc9SAart Bik 
790160d483bSAart Bik     // Construct allocation for each field.
791160d483bSAart Bik     Value sizeHint = op.getSizeHint();
7920e1708ffSAart Bik     SmallVector<Value> fields;
79383cf0dc9SAart Bik     createAllocFields(rewriter, loc, resType, enableBufferInitialization,
79483cf0dc9SAart Bik                       sizeHint, lvlSizesValues, fields);
795160d483bSAart Bik 
796d22df0ebSAart Bik     // Replace operation with resulting memrefs.
797aed43562SMatthias Springer     rewriter.replaceOpWithMultiple(op, {fields});
7980c7abd39SAart Bik     return success();
7990c7abd39SAart Bik   }
8007276b643Sbixia1 
8017276b643Sbixia1 private:
8027276b643Sbixia1   bool enableBufferInitialization;
8030c7abd39SAart Bik };
8040c7abd39SAart Bik 
8053e4a8c2cSAart Bik /// Sparse codegen rule for the empty tensor operator.
8063e4a8c2cSAart Bik class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
8073e4a8c2cSAart Bik public:
8083e4a8c2cSAart Bik   using OpConversionPattern::OpConversionPattern;
809206fad0eSMatthias Springer   SparseTensorEmptyConverter(const TypeConverter &typeConverter,
810206fad0eSMatthias Springer                              MLIRContext *context, bool enableInit)
8113e4a8c2cSAart Bik       : OpConversionPattern(typeConverter, context),
8123e4a8c2cSAart Bik         enableBufferInitialization(enableInit) {}
8133e4a8c2cSAart Bik 
8143e4a8c2cSAart Bik   LogicalResult
8153e4a8c2cSAart Bik   matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
8163e4a8c2cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8173e4a8c2cSAart Bik     const auto resType = getSparseTensorType(op);
8183e4a8c2cSAart Bik     if (!resType.hasEncoding())
8193e4a8c2cSAart Bik       return failure();
8202cc4b3d0SPeiming Liu 
8212cc4b3d0SPeiming Liu     if (!resType.isIdentity()) {
8222cc4b3d0SPeiming Liu       return rewriter.notifyMatchFailure(
8232cc4b3d0SPeiming Liu           op, "try run --sparse-reinterpret-map before codegen");
8242cc4b3d0SPeiming Liu     }
8252cc4b3d0SPeiming Liu 
82683cf0dc9SAart Bik     Location loc = op.getLoc();
8272cc4b3d0SPeiming Liu     // Level size equals to dimension size since lvl2dim map is an identity map.
82883cf0dc9SAart Bik     SmallVector<Value> lvlSizesValues;
82983cf0dc9SAart Bik     createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
8302cc4b3d0SPeiming Liu                    /*dimSizesValues=*/lvlSizesValues);
8313e4a8c2cSAart Bik     // Construct allocation for each field.
832160d483bSAart Bik     Value sizeHint; // none
8333e4a8c2cSAart Bik     SmallVector<Value> fields;
83483cf0dc9SAart Bik     createAllocFields(rewriter, loc, resType, enableBufferInitialization,
83583cf0dc9SAart Bik                       sizeHint, lvlSizesValues, fields);
836160d483bSAart Bik 
8373e4a8c2cSAart Bik     // Replace operation with resulting memrefs.
838aed43562SMatthias Springer     rewriter.replaceOpWithMultiple(op, {fields});
8393e4a8c2cSAart Bik     return success();
8403e4a8c2cSAart Bik   }
8413e4a8c2cSAart Bik 
8423e4a8c2cSAart Bik private:
8433e4a8c2cSAart Bik   bool enableBufferInitialization;
8443e4a8c2cSAart Bik };
8453e4a8c2cSAart Bik 
8462ddfacd9SAart Bik /// Sparse codegen rule for the dealloc operator.
8472ddfacd9SAart Bik class SparseTensorDeallocConverter
8482ddfacd9SAart Bik     : public OpConversionPattern<bufferization::DeallocTensorOp> {
8492ddfacd9SAart Bik public:
8502ddfacd9SAart Bik   using OpConversionPattern::OpConversionPattern;
851206fad0eSMatthias Springer   SparseTensorDeallocConverter(const TypeConverter &typeConverter,
852c44d307cSPeiming Liu                                MLIRContext *context, bool createDeallocs)
853c44d307cSPeiming Liu       : OpConversionPattern(typeConverter, context),
854c44d307cSPeiming Liu         createDeallocs(createDeallocs) {}
855c44d307cSPeiming Liu 
8562ddfacd9SAart Bik   LogicalResult
857*9df63b26SMatthias Springer   matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
8582ddfacd9SAart Bik                   ConversionPatternRewriter &rewriter) const override {
8592ddfacd9SAart Bik     auto enc = getSparseTensorEncoding(op.getTensor().getType());
8602ddfacd9SAart Bik     if (!enc)
8612ddfacd9SAart Bik       return failure();
862edca72f5SPeiming Liu 
863c44d307cSPeiming Liu     // If user requests not to deallocate sparse tensors, simply erase the
864c44d307cSPeiming Liu     // operation.
865c44d307cSPeiming Liu     if (createDeallocs) {
866edca72f5SPeiming Liu       // Replace the sparse tensor deallocation with field deallocations.
867edca72f5SPeiming Liu       Location loc = op.getLoc();
868204234a6SMatthias Springer       auto desc = getDescriptorFromTensorTuple(
869204234a6SMatthias Springer           adaptor.getTensor(),
870204234a6SMatthias Springer           cast<RankedTensorType>(op.getTensor().getType()));
871988733c6SPeiming Liu       for (auto input : desc.getMemRefFields())
872edca72f5SPeiming Liu         // Deallocate every buffer used to store the sparse tensor handler.
873edca72f5SPeiming Liu         rewriter.create<memref::DeallocOp>(loc, input);
874c44d307cSPeiming Liu     }
8752ddfacd9SAart Bik     rewriter.eraseOp(op);
8762ddfacd9SAart Bik     return success();
8772ddfacd9SAart Bik   }
878c44d307cSPeiming Liu 
879c44d307cSPeiming Liu private:
880fd2211d8SPeiming Liu   const bool createDeallocs;
8812ddfacd9SAart Bik };
8822ddfacd9SAart Bik 
8830c7abd39SAart Bik /// Sparse codegen rule for tensor rematerialization.
8840c7abd39SAart Bik class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
8850c7abd39SAart Bik public:
8860c7abd39SAart Bik   using OpConversionPattern::OpConversionPattern;
8870c7abd39SAart Bik   LogicalResult
888*9df63b26SMatthias Springer   matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
8890c7abd39SAart Bik                   ConversionPatternRewriter &rewriter) const override {
890191c43f6SPeiming Liu     // Prepare descriptor.
891204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
892204234a6SMatthias Springer                                              op.getTensor().getType());
8939f596a7cSAart Bik     // Generate optional insertion finalization code.
8949f596a7cSAart Bik     if (op.getHasInserts())
895191c43f6SPeiming Liu       genEndInsert(rewriter, op.getLoc(), desc);
896d22df0ebSAart Bik     // Replace operation with resulting memrefs.
897aed43562SMatthias Springer     rewriter.replaceOpWithMultiple(op, {desc.getFields()});
8980c7abd39SAart Bik     return success();
8990c7abd39SAart Bik   }
9000c7abd39SAart Bik };
9010c7abd39SAart Bik 
9028a583bd5Sbixia1 /// Sparse codegen rule for the expand op.
9038a583bd5Sbixia1 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
9048a583bd5Sbixia1 public:
9058a583bd5Sbixia1   using OpConversionPattern::OpConversionPattern;
9068a583bd5Sbixia1   LogicalResult
907*9df63b26SMatthias Springer   matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
9088a583bd5Sbixia1                   ConversionPatternRewriter &rewriter) const override {
909191c43f6SPeiming Liu     if (!getSparseTensorEncoding(op.getTensor().getType()))
910191c43f6SPeiming Liu       return failure();
9118a583bd5Sbixia1     Location loc = op->getLoc();
912204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
913204234a6SMatthias Springer                                              op.getTensor().getType());
914f708a549Swren romano     const auto srcType = getSparseTensorType(op.getTensor());
9158a583bd5Sbixia1     Type eltType = srcType.getElementType();
9168a583bd5Sbixia1     Type boolType = rewriter.getIntegerType(1);
9178a583bd5Sbixia1     Type idxType = rewriter.getIndexType();
9188a583bd5Sbixia1     // All initialization should be done on entry of the loop nest.
9198a583bd5Sbixia1     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
920c780352dSPeiming Liu 
9218a583bd5Sbixia1     // Determine the size for access expansion (always the innermost stored
922c780352dSPeiming Liu     // level size).
923c780352dSPeiming Liu     const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
9248a583bd5Sbixia1     // Generate a memref for `sz` elements of type `t`.
925f708a549Swren romano     const auto genAlloc = [&](Type t) {
926f708a549Swren romano       const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
92783a50839SPeiming Liu       return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
9288a583bd5Sbixia1     };
9293986c869SAart Bik     // Allocate temporary buffers for values/filled-switch and added.
9308a583bd5Sbixia1     // We do not use stack buffers for this, since the expanded size may
9318a583bd5Sbixia1     // be rather large (as it envelops a single expanded dense dimension).
9328a583bd5Sbixia1     Value values = genAlloc(eltType);
9338a583bd5Sbixia1     Value filled = genAlloc(boolType);
9343986c869SAart Bik     Value added = genAlloc(idxType);
9358a583bd5Sbixia1     Value zero = constantZero(rewriter, loc, idxType);
9368a583bd5Sbixia1     // Reset the values/filled-switch to all-zero/false. Note that this
9378a583bd5Sbixia1     // introduces an O(N) operation into the computation, but this reset
9388a583bd5Sbixia1     // operation is amortized over the innermost loops for the access
9398a583bd5Sbixia1     // pattern expansion. As noted in the operation doc, we would like
9408a583bd5Sbixia1     // to amortize this setup cost even between kernels.
9418a583bd5Sbixia1     rewriter.create<linalg::FillOp>(
9428a583bd5Sbixia1         loc, ValueRange{constantZero(rewriter, loc, eltType)},
9438a583bd5Sbixia1         ValueRange{values});
9448a583bd5Sbixia1     rewriter.create<linalg::FillOp>(
9458a583bd5Sbixia1         loc, ValueRange{constantZero(rewriter, loc, boolType)},
9468a583bd5Sbixia1         ValueRange{filled});
94784cd51bbSwren romano     // Replace expansion op with these buffers and initial coordinate.
9488a583bd5Sbixia1     assert(op.getNumResults() == 4);
9493986c869SAart Bik     rewriter.replaceOp(op, {values, filled, added, zero});
9503986c869SAart Bik     return success();
9513986c869SAart Bik   }
9523986c869SAart Bik };
9533986c869SAart Bik 
/// Sparse codegen rule for the compress operator.
///
/// Folds the scratch buffers produced by a matching expansion (values,
/// filled-switch, added-coordinates, count) back into the sparse tensor
/// storage, inserting only the `count` set entries and resetting the
/// scratch state along the way.
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Unpack the 1:N converted sparse tensor operand into a mutable
    // descriptor over its storage fields.
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                op.getTensor().getType());
    // The scratch buffers of the expanded access pattern (1:1 operands).
    Value values = getSingleValue(adaptor.getValues());
    Value filled = getSingleValue(adaptor.getFilled());
    Value added = getSingleValue(adaptor.getAdded());
    Value count = getSingleValue(adaptor.getCount());
    const SparseTensorType dstType(desc.getRankedTensorType());
    Type eltType = dstType.getElementType();

    // If the innermost level is ordered, we need to sort the coordinates
    // in the "added" array prior to applying the compression.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      rewriter.create<SortOp>(
          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
    // While performing the insertions, we also need to reset the elements
    // of the values/filled-switch by only iterating over the set elements,
    // to ensure that the runtime complexity remains proportional to the
    // sparsity of the expanded access pattern.
    //
    // Generate
    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //      crd = added[i];
    //      value = values[crd];
    //      insert({lvlCoords, crd}, value);
    //      new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
    //      values[crd] = 0;
    //      filled[crd] = false;
    //      yield new_memrefs
    //    }
    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
    Value i = loop.getInductionVar();

    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    // Insertion arguments: all storage fields, then the flattened outer
    // level coordinates, the innermost coordinate, and the value.
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    // Reset the scratch state for this entry (value back to zero, switch
    // back to false) so the buffers can be reused by the next iteration.
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
    rewriter.create<scf::YieldOp>(loc, insertRet);

    rewriter.setInsertionPointAfter(loop);
    // Deallocate the buffers on exit of the full loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
    return success();
  }
};
10249f596a7cSAart Bik 
10259f596a7cSAart Bik /// Sparse codegen rule for the insert operator.
102694e27c26SPeiming Liu class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
10279f596a7cSAart Bik public:
10289f596a7cSAart Bik   using OpConversionPattern::OpConversionPattern;
10299f596a7cSAart Bik   LogicalResult
1030*9df63b26SMatthias Springer   matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
10319f596a7cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
1032*9df63b26SMatthias Springer     auto stt = getSparseTensorType(op.getDest());
103394e27c26SPeiming Liu     if (!stt.hasEncoding())
103494e27c26SPeiming Liu       return failure();
103594e27c26SPeiming Liu     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
103694e27c26SPeiming Liu 
1037ad469385SPeiming Liu     Location loc = op.getLoc();
1038204234a6SMatthias Springer     auto desc =
1039204234a6SMatthias Springer         getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1040ad469385SPeiming Liu     TypeRange flatSpTensorTps = desc.getFields().getTypes();
1041ad469385SPeiming Liu     SmallVector<Value> params = llvm::to_vector(desc.getFields());
1042*9df63b26SMatthias Springer     SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1043*9df63b26SMatthias Springer     params.append(flatIndices.begin(), flatIndices.end());
1044*9df63b26SMatthias Springer     params.push_back(getSingleValue(adaptor.getScalar()));
104594e27c26SPeiming Liu     SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1046ad469385SPeiming Liu                                     params, /*genCall=*/true);
1047ad469385SPeiming Liu     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1048d22df0ebSAart Bik     // Replace operation with resulting memrefs.
1049aed43562SMatthias Springer     rewriter.replaceOpWithMultiple(op, {ret});
10508a583bd5Sbixia1     return success();
10518a583bd5Sbixia1   }
10528a583bd5Sbixia1 };
10538a583bd5Sbixia1 
105484cd51bbSwren romano /// Sparse codegen rule for position accesses.
105584cd51bbSwren romano class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
10566607fdf7SAart Bik public:
105784cd51bbSwren romano   using OpAdaptor = typename ToPositionsOp::Adaptor;
105884cd51bbSwren romano   using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
10596607fdf7SAart Bik   LogicalResult
1060*9df63b26SMatthias Springer   matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
10616607fdf7SAart Bik                   ConversionPatternRewriter &rewriter) const override {
106284cd51bbSwren romano     // Replace the requested position access with corresponding field.
10635c511655SAart Bik     // The view is restricted to the actual size to ensure clients
10645c511655SAart Bik     // of this operation truly observe size, not capacity!
10655c511655SAart Bik     Location loc = op.getLoc();
10665c511655SAart Bik     Level lvl = op.getLevel();
1067204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1068204234a6SMatthias Springer                                              op.getTensor().getType());
10695c511655SAart Bik     auto mem = desc.getPosMemRef(lvl);
10705c511655SAart Bik     auto size = desc.getPosMemSize(rewriter, loc, lvl);
10715c511655SAart Bik     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
10726607fdf7SAart Bik     return success();
10736607fdf7SAart Bik   }
10746607fdf7SAart Bik };
10756607fdf7SAart Bik 
107684cd51bbSwren romano /// Sparse codegen rule for accessing the coordinates arrays.
107784cd51bbSwren romano class SparseToCoordinatesConverter
107884cd51bbSwren romano     : public OpConversionPattern<ToCoordinatesOp> {
1079edca72f5SPeiming Liu public:
108084cd51bbSwren romano   using OpAdaptor = typename ToCoordinatesOp::Adaptor;
108184cd51bbSwren romano   using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
108290aa4362Sbixia1   LogicalResult
1083*9df63b26SMatthias Springer   matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
108490aa4362Sbixia1                   ConversionPatternRewriter &rewriter) const override {
108584cd51bbSwren romano     // Replace the requested coordinates access with corresponding field.
10865c511655SAart Bik     // The view is restricted to the actual size to ensure clients
10875c511655SAart Bik     // of this operation truly observe size, not capacity!
10885c511655SAart Bik     Location loc = op.getLoc();
10895c511655SAart Bik     Level lvl = op.getLevel();
1090204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1091204234a6SMatthias Springer                                              op.getTensor().getType());
10925c511655SAart Bik     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
10935c511655SAart Bik     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
10945c511655SAart Bik       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
10955c511655SAart Bik       mem = genSliceToSize(rewriter, loc, mem, size);
10965c511655SAart Bik     }
10975c511655SAart Bik     rewriter.replaceOp(op, mem);
109890aa4362Sbixia1     return success();
1099edca72f5SPeiming Liu   }
1100edca72f5SPeiming Liu };
1101edca72f5SPeiming Liu 
110284cd51bbSwren romano /// Sparse codegen rule for accessing the linear coordinates buffer.
110384cd51bbSwren romano class SparseToCoordinatesBufferConverter
110484cd51bbSwren romano     : public OpConversionPattern<ToCoordinatesBufferOp> {
110581e3079dSbixia1 public:
110684cd51bbSwren romano   using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
110784cd51bbSwren romano   using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
110881e3079dSbixia1   LogicalResult
1109*9df63b26SMatthias Springer   matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
111081e3079dSbixia1                   ConversionPatternRewriter &rewriter) const override {
111184cd51bbSwren romano     // Replace the requested coordinates access with corresponding field.
11125c511655SAart Bik     // The view is restricted to the actual size to ensure clients
11135c511655SAart Bik     // of this operation truly observe size, not capacity!
11145c511655SAart Bik     Location loc = op.getLoc();
11155c511655SAart Bik     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1116204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1117204234a6SMatthias Springer                                              op.getTensor().getType());
11185c511655SAart Bik     auto mem = desc.getAOSMemRef();
11195c511655SAart Bik     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
11205c511655SAart Bik     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
112181e3079dSbixia1     return success();
112281e3079dSbixia1   }
112381e3079dSbixia1 };
112481e3079dSbixia1 
1125edca72f5SPeiming Liu /// Sparse codegen rule for value accesses.
112690aa4362Sbixia1 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1127edca72f5SPeiming Liu public:
112890aa4362Sbixia1   using OpAdaptor = typename ToValuesOp::Adaptor;
112990aa4362Sbixia1   using OpConversionPattern<ToValuesOp>::OpConversionPattern;
113090aa4362Sbixia1   LogicalResult
1131*9df63b26SMatthias Springer   matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
113290aa4362Sbixia1                   ConversionPatternRewriter &rewriter) const override {
113384cd51bbSwren romano     // Replace the requested values access with corresponding field.
11345c511655SAart Bik     // The view is restricted to the actual size to ensure clients
11355c511655SAart Bik     // of this operation truly observe size, not capacity!
11365c511655SAart Bik     Location loc = op.getLoc();
1137204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1138204234a6SMatthias Springer                                              op.getTensor().getType());
11395c511655SAart Bik     auto mem = desc.getValMemRef();
11405c511655SAart Bik     auto size = desc.getValMemSize(rewriter, loc);
11415c511655SAart Bik     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
114290aa4362Sbixia1     return success();
1143edca72f5SPeiming Liu   }
1144edca72f5SPeiming Liu };
1145edca72f5SPeiming Liu 
/// Sparse codegen rule for the convert operator.
///
/// Handles only conversions between tensors whose encodings agree up to
/// bitwidths: either a pure no-op (fold) or an element-wise buffer copy
/// with type casts. All other cases are left for the rewriting pass.
class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
    SparseTensorEncodingAttr encSrc =
        getSparseTensorEncoding(op.getSource().getType());
    // The output tensor can not be a slice and those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
    // Different encoding (except for different bitwidth) should be handled by
    // rewriting.
    // We need further rewrites if the input tensor is a slice too.
    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
        encSrc.isSlice()) {
      return failure();
    }

    Type retElemTp = op.getResult().getType().getElementType();
    Type srcElemTp = op.getSource().getType().getElementType();
    // Fold the trivial cases.
    if (retElemTp == srcElemTp && encDst == encSrc) {
      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
      return success();
    }
    //
    // Do element-wise type conversion without using InsertOp.
    //
    // for each memref in srcTensor:
    //   dst = memref.alloc
    //   if srcMemRefType != dstMemRefType:
    //     for every dst[i] = cast(src[i])
    //   else:
    //     dst = memref.copy(src)
    Location loc = op.getLoc();
    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                                op.getSource().getType());
    SmallVector<Value> fields;
    // Walk the destination storage layout; for each field either reuse the
    // source specifier or materialize a freshly allocated (and converted)
    // copy of the corresponding source memref.
    foreachFieldAndTypeInSparseTensor(
        SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
        [&rewriter, &fields, srcDesc,
         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
              LevelType /*lt*/) -> bool {
          // Simply reuses the storage specifier as it is an SSA value.
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(srcDesc.getSpecifier());
          } else {
            // Allocates new memrefs
            Value srcMem = srcDesc.getMemRefField(fIdx);
            // TODO: We can instead use the actual memSize in specifier, that
            // would require a subViewOp to avoid overflow when copying
            // values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
            auto dstMem = rewriter.create<memref::AllocOp>(
                loc, cast<MemRefType>(fTp), sz);
            if (fTp != srcMem.getType()) {
              // Converts elements type.
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                    Value casted = genCast(builder, loc, v,
                                           dstMem.getType().getElementType());
                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
                  });
            } else {
              // TODO: We can even reuse the same memref for the new tensor,
              // but that requires a `ref-counting` based memory management
              // for shared memrefs between multiple sparse tensors.
              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
            }
            fields.push_back(dstMem);
          }
          return true;
        });

    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};
123158b449c3Sbixia1 
/// Sparse codegen rule for the tensor.extract_slice operator: produces a
/// sparse tensor slice that shares all memrefs of the source but carries a
/// fresh storage specifier recording the offset/size/stride per dimension.
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    // TODO: We should check these in ExtractSliceOp::verify.
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    // Source and destination may only differ in their slice attributes.
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
                                                op.getSource().getType());

    // Create a new specifier of the slice type, seeded from the source
    // specifier, and install it on the descriptor.
    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
    desc.setSpecifier(newSpec);

    // Fills in slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      Dimension dim = idx;

      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set dynamic value here. But it would
      // requires us to fill the hole when casting a static slice to dynamic
      // slice.
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
                             dim, offsetV);

      // FIXME: we need to distinguish level sizes and dimension size for slices
      // here. Maybe we should store slice level sizes in a different array
      // instead of reusing it.
      assert(srcEnc.isIdentity());
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
                             sizeV);
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
                             dim, strideV);
    }

    // NOTE: we can not generate tuples directly from descriptor here, as the
    // descriptor is holding the original type, yet we want the slice type
    // here (they shared every memref but with an updated specifier).
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};
128703526904SPeiming Liu 
12880f3e4d1aSAart Bik /// Sparse codegen rule for number of entries operator.
12890f3e4d1aSAart Bik class SparseNumberOfEntriesConverter
12900f3e4d1aSAart Bik     : public OpConversionPattern<NumberOfEntriesOp> {
12910f3e4d1aSAart Bik public:
12920f3e4d1aSAart Bik   using OpConversionPattern::OpConversionPattern;
12930f3e4d1aSAart Bik   LogicalResult
1294*9df63b26SMatthias Springer   matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
12950f3e4d1aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
129663d31a4dSbixia1     // Query memSizes for the actually stored values.
1297de560888SPeiming Liu     // FIXME: the nse value computed in this way might be wrong when there is
1298d2e85179SYinying Li     // any "loose_compressed" level.
1299204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1300204234a6SMatthias Springer                                              op.getTensor().getType());
1301204234a6SMatthias Springer     rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
13020f3e4d1aSAart Bik     return success();
13030f3e4d1aSAart Bik   }
13040f3e4d1aSAart Bik };
13050f3e4d1aSAart Bik 
/// Sparse codegen rule for the sparse_tensor.assemble operator: wraps the
/// user-supplied level/value buffers into sparse tensor storage fields and
/// computes the storage specifier (level sizes and memory sizes) from the
/// contents of the position arrays.
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> fields;

    // Build the field list: a fresh specifier plus, for each other field,
    // the corresponding input buffer cast (and possibly flattened) to the
    // expected memref type.
    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, LevelType lt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
              // Flattens the buffer to batchLvlRank.
              auto reassoc = getReassociationForFlattening(
                  mem.getType(), stt.getBatchLvlRank());
              mem = rewriter.create<memref::CastOp>(
                  loc, fType,
                  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
            } else {
              mem = rewriter.create<memref::CastOp>(loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });

    MutSparseTensorDescriptor desc(stt, fields);
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c0; // index to the last value in the position array
    Value memSize = c1; // memory size for current array

    Level trailCOOStart = stt.getAoSCOOStart();
    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
    // Sets up SparseTensorSpecifier. The running `memSize`/`posBack` pair is
    // threaded through the levels: each level derives its memory sizes from
    // the sizes established by the level above it.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AOS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in position array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (lt.isa<LevelFormat::Dense>()) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }
      if (lt.isa<LevelFormat::Batch>()) {
        // Skips batch levels as it is not linearized.
        // FIXME: this assumes that every batch has the same number of nse, need
        // to be generalized to handle varied-size batches.
        continue;
      }

      if (isWithPosLT(lt)) {
        assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
        if (isLooseCompressedLT(lt)) {
          // Loose-compressed keeps two positions (lo/hi) per parent entry.
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedLT(lt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in position array is the memory size for next level.
        // FIXME: this assumes that every batch has the same number of nse, need
        // to be generalized to handle varied-size batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
      // FIXME: This seems to be unnecessarily complex, can we simplify it?
      if (lvl == trailCOOStart) {
        // The AoS COO buffer interleaves trailCOORank coordinates per entry.
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};
1419a41672e1SPeiming Liu 
14206ca47eb4SPeiming Liu struct SparseDisassembleOpConverter
14216ca47eb4SPeiming Liu     : public OpConversionPattern<DisassembleOp> {
1422d4db5289SPeiming Liu   using OpConversionPattern::OpConversionPattern;
1423206fad0eSMatthias Springer   SparseDisassembleOpConverter(const TypeConverter &typeConverter,
14246ca47eb4SPeiming Liu                                MLIRContext *context)
1425de560888SPeiming Liu       : OpConversionPattern(typeConverter, context) {}
1426d4db5289SPeiming Liu 
1427d4db5289SPeiming Liu   LogicalResult
1428*9df63b26SMatthias Springer   matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1429d4db5289SPeiming Liu                   ConversionPatternRewriter &rewriter) const override {
1430204234a6SMatthias Springer     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1431204234a6SMatthias Springer                                              op.getTensor().getType());
1432b2e6b735SPeiming Liu     Location loc = op.getLoc();
1433b2e6b735SPeiming Liu     SmallVector<Value> retMem;
1434a63d6a00SPeiming Liu     SmallVector<Value> retLen;
14351944c4f7SAart Bik     desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
14361944c4f7SAart Bik                                    &retLen](FieldIndex fid,
14371944c4f7SAart Bik                                             SparseTensorFieldKind fKind,
14381944c4f7SAart Bik                                             Level lvl, LevelType lt) -> bool {
1439b2e6b735SPeiming Liu       if (fKind == SparseTensorFieldKind::StorageSpec)
1440b2e6b735SPeiming Liu         return true;
1441b2e6b735SPeiming Liu       SparseTensorType stt(desc.getRankedTensorType());
1442b2e6b735SPeiming Liu       Value sz, src;
1443b2e6b735SPeiming Liu       TypedValue<BaseMemRefType> dst;
1444b2e6b735SPeiming Liu       if (fKind == SparseTensorFieldKind::ValMemRef) {
1445b2e6b735SPeiming Liu         sz = desc.getValMemSize(rewriter, loc);
1446b2e6b735SPeiming Liu         src = desc.getValMemRef();
1447b2e6b735SPeiming Liu         dst = genToMemref(rewriter, loc, op.getOutValues());
1448fc9f1d49SPeiming Liu 
1449fc9f1d49SPeiming Liu         retMem.push_back(dst);
145064df1c08SPeiming Liu         Type valLenTp = op.getValLen().getType();
1451fc9f1d49SPeiming Liu         retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1452b2e6b735SPeiming Liu       } else {
1453b2e6b735SPeiming Liu         assert(fKind == SparseTensorFieldKind::PosMemRef ||
1454b2e6b735SPeiming Liu                fKind == SparseTensorFieldKind::CrdMemRef);
1455b2e6b735SPeiming Liu 
1456b2e6b735SPeiming Liu         sz = fKind == SparseTensorFieldKind::PosMemRef
1457b2e6b735SPeiming Liu                  ? desc.getPosMemSize(rewriter, loc, lvl)
1458b2e6b735SPeiming Liu                  : desc.getCrdMemSize(rewriter, loc, lvl);
1459b2e6b735SPeiming Liu         src = desc.getMemRefField(fid);
1460b2e6b735SPeiming Liu         dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1461b2e6b735SPeiming Liu         retMem.push_back(dst);
146264df1c08SPeiming Liu         // Retrieves the corresponding level length type.
146364df1c08SPeiming Liu         Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1464098f46dcSPeiming Liu         retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1465b2e6b735SPeiming Liu       }
1466b2e6b735SPeiming Liu       Value flatOut = dst;
146752b69aa3SPeiming Liu       if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
146852b69aa3SPeiming Liu         auto reassoc =
146952b69aa3SPeiming Liu             getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1470b2e6b735SPeiming Liu         flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1471b2e6b735SPeiming Liu       }
1472b2e6b735SPeiming Liu       Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1473b2e6b735SPeiming Liu       Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1474b2e6b735SPeiming Liu       rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1475b2e6b735SPeiming Liu       return true;
1476b2e6b735SPeiming Liu     });
1477b2e6b735SPeiming Liu 
1478b2e6b735SPeiming Liu     // Converts MemRefs back to Tensors.
1479a63d6a00SPeiming Liu     SmallVector<Value> retValues = llvm::to_vector(
1480b2e6b735SPeiming Liu         llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1481b2e6b735SPeiming Liu           return rewriter.create<bufferization::ToTensorOp>(loc, v);
1482b2e6b735SPeiming Liu         }));
1483a63d6a00SPeiming Liu     // Appends the actual memory length used in each buffer returned.
1484a63d6a00SPeiming Liu     retValues.append(retLen.begin(), retLen.end());
1485a63d6a00SPeiming Liu     rewriter.replaceOp(op, retValues);
1486b2e6b735SPeiming Liu     return success();
1487d4db5289SPeiming Liu   }
1488dc6427d6SPeiming Liu };
1489dc6427d6SPeiming Liu 
1490d3af6535SAart Bik struct SparseNewConverter : public OpConversionPattern<NewOp> {
14912c81d432Sbixia1   using OpConversionPattern::OpConversionPattern;
14922c81d432Sbixia1   LogicalResult
14932c81d432Sbixia1   matchAndRewrite(NewOp op, OpAdaptor adaptor,
14942c81d432Sbixia1                   ConversionPatternRewriter &rewriter) const override {
14952c81d432Sbixia1     Location loc = op.getLoc();
14962c81d432Sbixia1     const auto dstTp = getSparseTensorType(op.getResult());
14972c81d432Sbixia1     // Creating COO with NewOp is handled by direct IR codegen. All other cases
14982c81d432Sbixia1     // are handled by rewriting.
14995248a987SPeiming Liu     if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
15002c81d432Sbixia1       return failure();
15012c81d432Sbixia1 
1502d3af6535SAart Bik     // Implement as follows:
1503b86d3cbcSAart Bik     //   %reader = @createCheckedSparseTensorReader(%filename)
150484cd51bbSwren romano     //   %nse = @getSparseTensorNSE(%reader)
150584cd51bbSwren romano     //   %coo = bufferization.alloc_tensor an ordered COO with
150684cd51bbSwren romano     //          dst dim ordering, size_hint = %nse
150784cd51bbSwren romano     //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
150884cd51bbSwren romano     //   %values = sparse_tensor.values(%coo)
150984cd51bbSwren romano     //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
151084cd51bbSwren romano     //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
15112c81d432Sbixia1     //   update storage specifier
151284cd51bbSwren romano     //   @delSparseTensorReader(%reader)
151383cf0dc9SAart Bik     SmallVector<Value> dimSizesValues;
1514d3af6535SAart Bik     Value dimSizesBuffer;
1515d3af6535SAart Bik     Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
151683cf0dc9SAart Bik                              dimSizesValues, dimSizesBuffer);
15172c81d432Sbixia1 
1518d3af6535SAart Bik     // Get the number of stored entries.
151984cd51bbSwren romano     const Type indexTp = rewriter.getIndexType();
1520d3af6535SAart Bik     Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1521d3af6535SAart Bik                                {indexTp}, {reader}, EmitCInterface::Off)
1522d3af6535SAart Bik                     .getResult(0);
15232c81d432Sbixia1 
152483cf0dc9SAart Bik     // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
15252323f48eSAart Bik     SmallVector<Value> lvlSizesValues;
1526d3af6535SAart Bik     Value dim2lvlBuffer;
1527d3af6535SAart Bik     Value lvl2dimBuffer;
152883cf0dc9SAart Bik     genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
15292323f48eSAart Bik                   lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
15302c81d432Sbixia1 
153183cf0dc9SAart Bik     // Construct allocation for each field.
153283cf0dc9SAart Bik     Value sizeHint = nse;
153383cf0dc9SAart Bik     SmallVector<Value> fields;
153483cf0dc9SAart Bik     createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
153583cf0dc9SAart Bik                       lvlSizesValues, fields);
153683cf0dc9SAart Bik 
153784cd51bbSwren romano     // Read the COO tensor data.
1538160d483bSAart Bik     MutSparseTensorDescriptor desc(dstTp, fields);
153984cd51bbSwren romano     Value xs = desc.getAOSMemRef();
154084cd51bbSwren romano     Value ys = desc.getValMemRef();
154184cd51bbSwren romano     const Type boolTp = rewriter.getIntegerType(1);
154284cd51bbSwren romano     const Type elemTp = dstTp.getElementType();
154384cd51bbSwren romano     const Type crdTp = dstTp.getCrdType();
1544b86d3cbcSAart Bik     SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
154584cd51bbSwren romano                                           overheadTypeFunctionSuffix(crdTp),
154684cd51bbSwren romano                                           primaryTypeFunctionSuffix(elemTp)};
15472c81d432Sbixia1     Value isSorted =
154884cd51bbSwren romano         createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1549d3af6535SAart Bik                        {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1550d3af6535SAart Bik                        EmitCInterface::On)
15512c81d432Sbixia1             .getResult(0);
15522c81d432Sbixia1 
15532c81d432Sbixia1     // If the destination tensor is a sorted COO, we need to sort the COO tensor
15542c81d432Sbixia1     // data if the input elements aren't sorted yet.
1555d3af6535SAart Bik     const Level lvlRank = dstTp.getLvlRank();
155684cd51bbSwren romano     if (dstTp.isOrderedLvl(lvlRank - 1)) {
155784cd51bbSwren romano       Value kFalse = constantI1(rewriter, loc, false);
15582c81d432Sbixia1       Value notSorted = rewriter.create<arith::CmpIOp>(
155984cd51bbSwren romano           loc, arith::CmpIPredicate::eq, isSorted, kFalse);
15602c81d432Sbixia1       scf::IfOp ifOp =
15612c81d432Sbixia1           rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
15622c81d432Sbixia1       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1563bfa3bc43SPeiming Liu       auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
15640083f833SPeiming Liu       rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1565bfa3bc43SPeiming Liu                               rewriter.getIndexAttr(0),
1566bfa3bc43SPeiming Liu                               SparseTensorSortKind::HybridQuickSort);
15672c81d432Sbixia1       rewriter.setInsertionPointAfter(ifOp);
15682c81d432Sbixia1     }
15692c81d432Sbixia1 
157084cd51bbSwren romano     // Set PosMemRef0[1] = nse.
157184cd51bbSwren romano     const Value c1 = constantIndex(rewriter, loc, 1);
157284cd51bbSwren romano     const Value posMemref0 = desc.getPosMemRef(0);
157384cd51bbSwren romano     const Type posTp = dstTp.getPosType();
157484cd51bbSwren romano     const Value posNse = genCast(rewriter, loc, nse, posTp);
157584cd51bbSwren romano     rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
15762c81d432Sbixia1 
15772c81d432Sbixia1     // Update storage specifier.
157884cd51bbSwren romano     Value coordinatesSize = rewriter.create<arith::MulIOp>(
157984cd51bbSwren romano         loc, nse, constantIndex(rewriter, loc, lvlRank));
158084cd51bbSwren romano     desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
158184cd51bbSwren romano                            coordinatesSize);
15822c81d432Sbixia1     desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
15832c81d432Sbixia1                            std::nullopt, nse);
15842c81d432Sbixia1 
15852c81d432Sbixia1     // Release the sparse tensor reader.
15862c81d432Sbixia1     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
15872c81d432Sbixia1                    EmitCInterface::Off);
15882c81d432Sbixia1 
15892c81d432Sbixia1     // Replace operation with resulting memrefs.
1590aed43562SMatthias Springer     rewriter.replaceOpWithMultiple(op, {fields});
15912c81d432Sbixia1     return success();
15922c81d432Sbixia1   }
15932c81d432Sbixia1 };
15942c81d432Sbixia1 
1595e8e8df4cSMatthias Springer struct SparseHasRuntimeLibraryConverter
1596e8e8df4cSMatthias Springer     : public OpConversionPattern<HasRuntimeLibraryOp> {
1597e8e8df4cSMatthias Springer   using OpConversionPattern::OpConversionPattern;
1598e8e8df4cSMatthias Springer   LogicalResult
1599e8e8df4cSMatthias Springer   matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1600e8e8df4cSMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
1601e8e8df4cSMatthias Springer     auto i1Type = rewriter.getI1Type();
1602e8e8df4cSMatthias Springer     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1603e8e8df4cSMatthias Springer         op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1604e8e8df4cSMatthias Springer     return success();
1605e8e8df4cSMatthias Springer   }
1606e8e8df4cSMatthias Springer };
1607e8e8df4cSMatthias Springer 
160886b22d31SAart Bik } // namespace
160986b22d31SAart Bik 
161086b22d31SAart Bik //===----------------------------------------------------------------------===//
161186b22d31SAart Bik // Public method for populating conversion rules.
161286b22d31SAart Bik //===----------------------------------------------------------------------===//
161386b22d31SAart Bik 
161486b22d31SAart Bik /// Populates the given patterns list with conversion rules required for
161586b22d31SAart Bik /// the sparsification of linear algebra operations.
16167276b643Sbixia1 void mlir::populateSparseTensorCodegenPatterns(
1617206fad0eSMatthias Springer     const TypeConverter &typeConverter, RewritePatternSet &patterns,
1618c44d307cSPeiming Liu     bool createSparseDeallocs, bool enableBufferInitialization) {
1619e8e8df4cSMatthias Springer   patterns.add<
1620e8e8df4cSMatthias Springer       SparseAssembleOpConverter, SparseDisassembleOpConverter,
1621c780352dSPeiming Liu       SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1622de560888SPeiming Liu       SparseCastConverter, SparseExtractSliceConverter,
1623e8e8df4cSMatthias Springer       SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1624e8e8df4cSMatthias Springer       SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
16256db397a8SPeiming Liu       SparseSliceGetterOpConverter<ToSliceOffsetOp,
16266db397a8SPeiming Liu                                    StorageSpecifierKind::DimOffset>,
16276db397a8SPeiming Liu       SparseSliceGetterOpConverter<ToSliceStrideOp,
16286db397a8SPeiming Liu                                    StorageSpecifierKind::DimStride>,
16296db397a8SPeiming Liu       SparseToPositionsConverter, SparseToCoordinatesConverter,
16306db397a8SPeiming Liu       SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1631d3af6535SAart Bik       SparseConvertConverter, SparseNewConverter,
1632e8e8df4cSMatthias Springer       SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1633e8e8df4cSMatthias Springer       typeConverter, patterns.getContext());
1634de560888SPeiming Liu   patterns.add<SparseTensorDeallocConverter>(
1635c44d307cSPeiming Liu       typeConverter, patterns.getContext(), createSparseDeallocs);
16363e4a8c2cSAart Bik   patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
16373e4a8c2cSAart Bik       typeConverter, patterns.getContext(), enableBufferInitialization);
163886b22d31SAart Bik }
1639