186b22d31SAart Bik //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===// 286b22d31SAart Bik // 386b22d31SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 486b22d31SAart Bik // See https://llvm.org/LICENSE.txt for license information. 586b22d31SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 686b22d31SAart Bik // 786b22d31SAart Bik //===----------------------------------------------------------------------===// 886b22d31SAart Bik // 986b22d31SAart Bik // A pass that converts sparse tensor types and primitives to actual compiler 1086b22d31SAart Bik // visible buffers and actual compiler IR that implements these primitives on 1186b22d31SAart Bik // the selected sparse tensor storage schemes. This pass provides an alternative 1286b22d31SAart Bik // to the SparseTensorConversion pass, eliminating the dependence on a runtime 13bc61122aSAart Bik // support library (other than for file I/O), and providing many more 14bc61122aSAart Bik // opportunities for subsequent compiler optimization of the generated code. 
1586b22d31SAart Bik // 1686b22d31SAart Bik //===----------------------------------------------------------------------===// 1786b22d31SAart Bik 18365777ecSAart Bik #include "Utils/CodegenUtils.h" 19365777ecSAart Bik #include "Utils/SparseTensorDescriptor.h" 2086b22d31SAart Bik 216db397a8SPeiming Liu #include "mlir/Dialect/Arith/Utils/Utils.h" 222ddfacd9SAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 2386b22d31SAart Bik #include "mlir/Dialect/Func/IR/FuncOps.h" 248a583bd5Sbixia1 #include "mlir/Dialect/Linalg/Utils/Utils.h" 2586b22d31SAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h" 26840e2ba3Sbixia1 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 2786b22d31SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 28f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 2986b22d31SAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 3086b22d31SAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h" 3186b22d31SAart Bik #include "mlir/Transforms/DialectConversion.h" 32f708a549Swren romano 33a1fe1f5fSKazu Hirata #include <optional> 3486b22d31SAart Bik 3586b22d31SAart Bik using namespace mlir; 3686b22d31SAart Bik using namespace mlir::sparse_tensor; 3786b22d31SAart Bik 3886b22d31SAart Bik //===----------------------------------------------------------------------===// 3986b22d31SAart Bik // Helper methods. 4086b22d31SAart Bik //===----------------------------------------------------------------------===// 4186b22d31SAart Bik 42*9df63b26SMatthias Springer /// Flatten the given value ranges into a single vector of values. 
43*9df63b26SMatthias Springer static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { 44*9df63b26SMatthias Springer SmallVector<Value> result; 45*9df63b26SMatthias Springer for (const auto &vals : values) 46*9df63b26SMatthias Springer llvm::append_range(result, vals); 47*9df63b26SMatthias Springer return result; 48edca72f5SPeiming Liu } 49*9df63b26SMatthias Springer 50*9df63b26SMatthias Springer /// Assert that the given value range contains a single value and return it. 51*9df63b26SMatthias Springer static Value getSingleValue(ValueRange values) { 52*9df63b26SMatthias Springer assert(values.size() == 1 && "expected single value"); 53*9df63b26SMatthias Springer return values.front(); 545ab1a8aeSPeiming Liu } 55edca72f5SPeiming Liu 5684cd51bbSwren romano /// Generates a load with proper `index` typing. 5770633a8dSAart Bik static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) { 5844ff23d5SPeiming Liu idx = genCast(builder, loc, idx, builder.getIndexType()); 5970633a8dSAart Bik return builder.create<memref::LoadOp>(loc, mem, idx); 6070633a8dSAart Bik } 6170633a8dSAart Bik 6284cd51bbSwren romano /// Generates a store with proper `index` typing and proper value. 6370633a8dSAart Bik static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, 6470633a8dSAart Bik Value idx) { 6544ff23d5SPeiming Liu idx = genCast(builder, loc, idx, builder.getIndexType()); 6644ff23d5SPeiming Liu val = genCast(builder, loc, val, 675550c821STres Popp cast<ShapedType>(mem.getType()).getElementType()); 6870633a8dSAart Bik builder.create<memref::StoreOp>(loc, val, mem, idx); 6970633a8dSAart Bik } 7070633a8dSAart Bik 7170633a8dSAart Bik /// Creates a straightforward counting for-loop. 
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  Type indexType = builder.getIndexType();
  // Default lower bound is zero; step is always one.
  if (!lower)
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
  // Replace the caller's `fields` with the loop's region iteration arguments,
  // so code built inside the loop body uses the loop-carried values.
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  // Leave the builder positioned inside the loop body for the caller.
  builder.setInsertionPointToStart(forOp.getBody());
  return forOp;
}

/// Creates a push back operation.
static void createPushback(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc,
                           SparseTensorFieldKind kind, std::optional<Level> lvl,
                           Value value, Value repeat = Value()) {
  // Look up the target memref field and its element type; `lvl` is
  // std::nullopt for level-independent fields such as the values array.
  Type etp = desc.getMemRefElementType(kind, lvl);
  Value field = desc.getMemRefField(kind, lvl);
  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);

  auto pushBackOp = builder.create<PushBackOp>(
      loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
      genCast(builder, loc, value, etp), repeat);

  // Write back the (possibly reallocated) buffer and the new size into the
  // descriptor, keeping it consistent for subsequent code generation.
  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                         pushBackOp.getNewSize());
}

/// Generates code that allocates a sparse storage scheme for given rank.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
                               MutSparseTensorDescriptor desc, Level startLvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  // `linear` accumulates the product of dense level sizes seen so far.
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times. For loose
      // compression, we multiply linear by two in order to append both the
      // lo/hi positions.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      if (isLooseCompressedLT(lt)) {
        Value two = constantIndex(builder, loc, 2);
        linear = builder.create<arith::MulIOp>(loc, linear, two);
      }
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
      return;
    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
      return; // nothing to do
    }
    // Keep compounding the size, but nothing needs to be initialized
    // at this level. We will eventually reach a compressed level or
    // otherwise the values array for the from-here "all-dense" case.
    assert(isDenseLT(lt));
    Value size = desc.getLvlSize(builder, loc, lvl);
    linear = builder.create<arith::MulIOp>(loc, linear, size);
  }
  // Reached values array so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
}

/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  // `sz` is the dynamic capacity of the buffer (see createAllocFields).
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  // Optionally zero-initialize the buffer contents.
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}

/// Creates the dim sizes array, filling in from dynamic sizes.
static void createDimSizes(OpBuilder &builder, Location loc,
                           SparseTensorType stt, ValueRange dynSizes,
                           /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
  const Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  // Static dimensions become constants; dynamic dimensions consume the
  // next entry of `dynSizes` in order.
  unsigned i = 0;
  for (const Size sz : stt.getDimShape())
    dimSizesValues.push_back(ShapedType::isDynamic(sz)
                                 ? dynSizes[i++]
                                 : constantIndex(builder, loc, sz));
}

/// Creates allocation for each field in sparse tensor type. Note that
/// for all dynamic memrefs in the sparse tensor storage layout, the
/// memory size is really the capacity of the "vector", while the actual
/// size resides in the sizes array.
static void createAllocFields(OpBuilder &builder, Location loc,
                              SparseTensorType stt, bool enableInit,
                              Value sizeHint,
                              SmallVectorImpl<Value> &lvlSizesValues,
                              /*out*/ SmallVectorImpl<Value> &fields) {
  Level lvlRank = stt.getLvlRank();
  // Set up some heuristic sizes. We try to set the initial
  // size based on available information. Otherwise we just
  // initialize a few elements to start the reallocation chain.
  // TODO: refine this
  Value posHeuristic, crdHeuristic, valHeuristic;
  if (stt.isAllDense()) {
    // All-dense: the values array needs the full product of level sizes.
    valHeuristic = lvlSizesValues[0];
    for (Level lvl = 1; lvl < lvlRank; lvl++)
      valHeuristic =
          builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
  } else if (sizeHint) {
    if (stt.getAoSCOOStart() == 0) {
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      // CSR-like: one position per dense row, plus one.
      posHeuristic = builder.create<arith::AddIOp>(
          loc, sizeHint, constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    // No information available: start small, rely on push_back reallocation.
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
  }
  // Initializes all fields. An initial storage specifier and allocated
  // positions/coordinates/values memrefs (with heuristic capacity).
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
       enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                   Level /*lvl*/, LevelType /*lt*/) -> bool {
        assert(fields.size() == fIdx);
        Value field;
        switch (fKind) {
        case SparseTensorFieldKind::StorageSpec:
          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
          break;
        case SparseTensorFieldKind::PosMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   posHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::CrdMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   crdHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::ValMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   valHeuristic, enableInit);
          break;
        }
        assert(field);
        fields.push_back(field);
        // Returns true to continue the iteration.
        return true;
      });
  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
  // and gives all position fields an initial zero entry, so that it is
  // easier to maintain the "linear + 1" length property.
  MutSparseTensorDescriptor desc(stt, fields);
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*startLvl=*/0);
}

/// Helper method that generates block specific to compressed case:
///
///   // given: parentPos = posCursor[lvl-1]
///   pstart = desc.positions[lvl][parentPos]
///   pstop = desc.positions[lvl][parentPos+1]
///   plast = pstop - 1
///   msz = desc.coordinates[lvl].size()
///   if (pstart < pstop) {
///     isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
///   } else { // first insertion
///     isPresent = false
///     desc.positions[lvl][parentPos] = msz
///   }
///   if (isPresent) { // coordinate is already present
///     pnext = plast
///   } else {
///     desc.coordinates[lvl].push_back(lvlCoords[lvl])
///     desc.positions[lvl][parentPos+1] = msz+1
///     pnext = msz
///     <prepare level lvl+1>
///   }
///   posCursor[lvl] = pnext
static Value
genCompressed(OpBuilder &builder, Location loc,
              MutSparseTensorDescriptor desc, ValueRange lvlCoords,
              Value /*unused*/, Value parentPos, Level lvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  SmallVector<Type> types;
  Type indexType = builder.getIndexType();
  Type boolType = builder.getIntegerType(1);
  // Coordinates may be interleaved (AoS COO); account for the stride.
  unsigned crdFidx;
  unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  // Number of stored coordinates at this level (memory size / stride).
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
  types.pop_back();
  // Then-branch: region not empty, test whether the last stored coordinate
  // matches the coordinate being inserted.
  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  Value crd =
      genLoad(builder, loc, desc.getMemRefField(crdFidx),
              crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                         : plast);
  Value eq = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
      lvlCoords[lvl]);
  builder.create<scf::YieldOp>(loc, eq);
  // Else-branch: first insertion at this parent position.
  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  if (lvl > 0)
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
  builder.setInsertionPointAfter(ifOp1);
  // If present construct. Note that for a non-unique dimension level, we
  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
  //
  // TODO: generate less temporary IR?
  //
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
  // If present (fields unaffected, update pnext to plast).
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());

  // FIXME: This does not look like a clean way, but probably the most
  // efficient way.
  desc.getFields().push_back(plast);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // If !present (changes fields, update pnext).
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 /*value=*/lvlCoords[lvl]);
  // Prepare the next level "as needed".
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);

  desc.getFields().push_back(msz);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // Update fields and return next pos.
  builder.setInsertionPointAfter(ifOp2);
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}

/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt)) {
      // Compressed dimensions need a position cleanup for all entries
      // that were not visited during the insertion pass.
      //
      // TODO: avoid cleanup and keep compressed scheme consistent at all
      // times?
      //
      if (lvl > 0) {
        Type posType = stt.getPosType();
        Value posMemRef = desc.getPosMemRef(lvl);
        Value hi = desc.getPosMemSize(builder, loc, lvl);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        // Walk the positions array, replacing any zero entry (never written
        // during the insertion pass) with the running previous value `oldv`.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
        Value i = loop.getInductionVar();
        Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
        Value cond = builder.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, newv, posZero);
        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
                                                   cond, /*else*/ true);
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
        builder.create<scf::YieldOp>(loc, oldv);
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<scf::YieldOp>(loc, newv);
        builder.setInsertionPointAfter(ifOp);
        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
        builder.setInsertionPointAfter(loop);
      }
    } else {
      // All other level types need no finalization.
      assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
             isNOutOfMLT(lt));
    }
  }
}

/// Generates a subview into the sizes.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto memTp = llvm::cast<MemRefType>(mem.getType());
  // For higher-dimensional memrefs, we assume that the innermost
  // dimension is always of the right size.
  // TODO: generate complex truncating view here too?
  if (memTp.getRank() > 1)
    return mem;
  // Truncate linear memrefs to given size.
  return builder
      .create<memref::SubViewOp>(
          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
          ArrayRef<int64_t>{0},                    // static offset
          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
          ArrayRef<int64_t>{1})                    // static stride
      .getResult();
}

/// Creates the reassociation array.
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
  // Create reassociation in the form:
  // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
  for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);

  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);

  return ret;
}

//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//

namespace {

/// Helper class to help lowering sparse_tensor.insert operation.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
public:
  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
                        bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};

  /// Generates code along an insertion path without the need for a "cursor".
  /// This current insertion strategy comes at the expense of some testing
  /// overhead for each insertion. The strategy will be optimized later for
  /// common insertion patterns. The current insertion strategy also assumes
  /// insertions occur in "a reasonable order" that enables building the
  /// storage scheme in an appending/inserting kind of fashion (i.e. no
  /// in-between insertions that need data movement). The implementation
  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
  ///
  /// The `args` layout is: [fields..., lvlCoords..., value].
  ///
  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                       OpBuilder &builder, Location loc) {
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    const Level lvlRank = stt.getLvlRank();
    // Extract fields and coordinates from args.
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
    MutSparseTensorDescriptor desc(stt, fields);
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
    Value value = args.back();
    // `parentPos` threads the insertion position from one level to the next.
    Value parentPos = constantZero(builder, loc, builder.getIndexType());
    // Generate code for every level.
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        // Create:
        //   if (!present) {
        //     coordinates[lvl].push_back(coords[lvl])
        //     <update positions and prepare level lvl + 1>
        //   }
        //   positions[lvl] = coordinates.size() - 1
        //   <insert @ positions[lvl] at next level lvl + 1>
        if (isLooseCompressedLT(lt)) {
          // Loose compression stores lo/hi pairs, hence the doubling.
          Value two = constantIndex(builder, loc, 2);
          parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
        }
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
      } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
        // Create:
        //   coordinates[lvl].push_back(coords[lvl])
        //   positions[lvl] = positions[lvl-1]
        //   <insert @ positions[lvl] at next level lvl + 1>
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
                       lvl, /*value=*/coords[lvl]);
      } else {
        assert(isDenseLT(lt));
        // Construct the new position as:
        //   positions[lvl] = size * positions[lvl-1] + coords[lvl]
        //   <insert @ positions[lvl] at next level lvl + 1>
        Value size = desc.getLvlSize(builder, loc, lvl);
        Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
        parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
      }
    }
    // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
    else
      genStore(builder, loc, value, desc.getValMemRef(), parentPos);
    return fields;
  }

  /// Builds a unique name that encodes the full storage configuration, so
  /// that distinct sparse tensor types get distinct insertion functions.
  std::string getMangledFuncName() {
    // The mangled name of the function has this format:
    //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    SmallString<32> nameBuffer;
    llvm::raw_svector_ostream nameOstream(nameBuffer);
    nameOstream << kInsertFuncNamePrefix;
    const Level lvlRank = stt.getLvlRank();
    for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace/remove punctuations in level properties.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
    }
    // Static dim sizes are used in the generated code while dynamic sizes are
    // loaded from the dimSizes buffer. This is the reason for adding the shape
    // to the function name.
    for (const auto sz : stt.getDimShape())
      nameOstream << sz << "_";
    // Permutation information is also used in generating insertion.
    if (!stt.isIdentity())
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
    return nameOstream.str().str();
  }

private:
  // Ranked tensor type of the sparse tensor being inserted into.
  TensorType rtp;
};

/// Sparse tensor storage conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create a return with the flattened value extracted from sparse tensors.
566*9df63b26SMatthias Springer rewriter.replaceOpWithNewOp<func::ReturnOp>( 567*9df63b26SMatthias Springer op, flattenValues(adaptor.getOperands())); 568edca72f5SPeiming Liu return success(); 569edca72f5SPeiming Liu } 570edca72f5SPeiming Liu }; 571edca72f5SPeiming Liu 572edca72f5SPeiming Liu /// Sparse tensor storage conversion rule for calls. 573edca72f5SPeiming Liu class SparseCallConverter : public OpConversionPattern<func::CallOp> { 574edca72f5SPeiming Liu public: 575edca72f5SPeiming Liu // The default CallOp converter can not handle 1:N type conversion. 576edca72f5SPeiming Liu using OpConversionPattern::OpConversionPattern; 577edca72f5SPeiming Liu LogicalResult 578*9df63b26SMatthias Springer matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor, 579edca72f5SPeiming Liu ConversionPatternRewriter &rewriter) const override { 580edca72f5SPeiming Liu Location loc = op.getLoc(); 581edca72f5SPeiming Liu // In case of: 582edca72f5SPeiming Liu // sparse_tensor, f, sparse_tensor = call @foo(...) 583edca72f5SPeiming Liu // ==> 584edca72f5SPeiming Liu // memref..., f, memref = call @foo(...) replace with 585edca72f5SPeiming Liu // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor 5860e1708ffSAart Bik SmallVector<Type> finalRetTy; 587edca72f5SPeiming Liu if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy))) 588edca72f5SPeiming Liu return failure(); 589edca72f5SPeiming Liu 590be556ee1SYinying Li // (1) Generates new call with flattened return value. 591*9df63b26SMatthias Springer auto newCall = rewriter.create<func::CallOp>( 592*9df63b26SMatthias Springer loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands())); 593aed43562SMatthias Springer // (2) Gather sparse tensor returns. 
594aed43562SMatthias Springer SmallVector<SmallVector<Value>> packedResultVals; 595be556ee1SYinying Li // Tracks the offset of current return value (of the original call) 596edca72f5SPeiming Liu // relative to the new call (after sparse tensor flattening); 597edca72f5SPeiming Liu unsigned retOffset = 0; 598edca72f5SPeiming Liu // Temporal buffer to hold the flattened list of type for 599edca72f5SPeiming Liu // a sparse tensor. 6000e1708ffSAart Bik SmallVector<Type> sparseFlat; 601edca72f5SPeiming Liu for (auto ret : op.getResults()) { 602edca72f5SPeiming Liu assert(retOffset < newCall.getNumResults()); 603edca72f5SPeiming Liu auto retType = ret.getType(); 604edca72f5SPeiming Liu if (failed(typeConverter->convertType(retType, sparseFlat))) 605edca72f5SPeiming Liu llvm_unreachable("Failed to convert type in sparse tensor codegen"); 606edca72f5SPeiming Liu 607edca72f5SPeiming Liu // Converted types can not be empty when the type conversion succeed. 608edca72f5SPeiming Liu assert(!sparseFlat.empty()); 609edca72f5SPeiming Liu if (sparseFlat.size() > 1) { 610edca72f5SPeiming Liu auto flatSize = sparseFlat.size(); 611aed43562SMatthias Springer packedResultVals.emplace_back(); 612aed43562SMatthias Springer llvm::append_range(packedResultVals.back(), 613aed43562SMatthias Springer newCall.getResults().slice(retOffset, flatSize)); 614edca72f5SPeiming Liu retOffset += flatSize; 615edca72f5SPeiming Liu } else { 616edca72f5SPeiming Liu // If this is an 1:1 conversion, no need for casting. 
617aed43562SMatthias Springer packedResultVals.emplace_back(); 618aed43562SMatthias Springer packedResultVals.back().push_back(newCall.getResult(retOffset)); 619edca72f5SPeiming Liu retOffset++; 620edca72f5SPeiming Liu } 621edca72f5SPeiming Liu sparseFlat.clear(); 622edca72f5SPeiming Liu } 623edca72f5SPeiming Liu 624aed43562SMatthias Springer assert(packedResultVals.size() == op.getNumResults()); 625aed43562SMatthias Springer rewriter.replaceOpWithMultiple( 626aed43562SMatthias Springer op, llvm::to_vector_of<ValueRange>(packedResultVals)); 62786b22d31SAart Bik return success(); 62886b22d31SAart Bik } 62986b22d31SAart Bik }; 63086b22d31SAart Bik 631c780352dSPeiming Liu /// Sparse codegen rule for level accesses. 632c780352dSPeiming Liu class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { 6331be09496SAart Bik public: 6341be09496SAart Bik using OpConversionPattern::OpConversionPattern; 6351be09496SAart Bik LogicalResult 636*9df63b26SMatthias Springer matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor, 6371be09496SAart Bik ConversionPatternRewriter &rewriter) const override { 638c780352dSPeiming Liu std::optional<int64_t> lvl = op.getConstantLvlIndex(); 639204234a6SMatthias Springer RankedTensorType srcType = op.getSource().getType(); 640204234a6SMatthias Springer if (!lvl || !getSparseTensorEncoding(srcType)) 6411be09496SAart Bik return failure(); 642191c43f6SPeiming Liu 643204234a6SMatthias Springer auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType); 644c780352dSPeiming Liu auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl); 645191c43f6SPeiming Liu 64683a50839SPeiming Liu rewriter.replaceOp(op, sz); 6473ae98fd2SAart Bik return success(); 6483ae98fd2SAart Bik } 6493ae98fd2SAart Bik }; 6503ae98fd2SAart Bik 651dda3dc5eSPeiming Liu // TODO: use a new SortCOO operation here instead of reusing convert op. 
/// Sparse codegen rule for the reorder_coo operator: sorts the underlying
/// COO buffers in-place and reuses them for the result.
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();

    SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
    SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());

    // Should have been verified.
    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
           dstStt.isCOOType() && srcStt.isCOOType());
    assert(dstStt.hasSameDimToLvl(srcStt));

    // We don't need a mutable descriptor here as we perform sorting in-place.
    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                             op.getInputCoo().getType());
    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
    auto crd = desc.getAOSMemRef();
    auto val = desc.getValMemRef();

    // Otherwise we need another data shuffle and a non-identity map.
    assert(dstStt.hasSameDimToLvl(srcStt));
    (void)dstStt; // to silence warning when assertion is disabled

    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

    rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
                            rewriter.getIndexAttr(0), op.getAlgorithm());

    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
    return success();
  }
};

/// Sparse codegen rule for slice metadata accesses (offset/stride): lowers
/// to reading the corresponding storage-specifier field.
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;
  using typename OpConversionPattern<Op>::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());

    rewriter.replaceOp(op, v);
    return success();
  }
};

/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    // Identical encodings: the cast is a no-op on the storage fields.
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

/// Sparse codegen rule for the reinterpret_map operator: a no-op at this
/// level, since the storage fields are unaffected by remapping.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

/// Sparse codegen rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    Location loc = op.getLoc();
    // Deal with copy.
    if (op.getCopy()) {
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
      SmallVector<Value> fields;
      fields.reserve(desc.getNumFields());
      // Memcpy on memref fields.
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = rewriter.create<memref::DimOp>(loc, field, 0);
        auto copied =
            rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
        rewriter.create<memref::CopyOp>(loc, field, copied);
        fields.push_back(copied);
      }
      // Reuses specifier.
      fields.push_back(desc.getSpecifier());
      assert(fields.size() == desc.getNumFields());
      rewriter.replaceOpWithMultiple(op, {fields});
      return success();
    }

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // Level size equals the dimension size since lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType,
                   flattenValues(adaptor.getDynamicSizes()),
                   /*dimSizesValues=*/lvlSizesValues);

    // Construct allocation for each field.
    Value sizeHint = op.getSizeHint();
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  // When set, newly allocated buffers are zero-initialized.
  bool enableBufferInitialization;
};

/// Sparse codegen rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }

    Location loc = op.getLoc();
    // Level size equals the dimension size since lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
                   /*dimSizesValues=*/lvlSizesValues);
    // Construct allocation for each field.
    Value sizeHint; // none
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  // When set, newly allocated buffers are zero-initialized.
  bool enableBufferInitialization;
};

/// Sparse codegen rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto enc = getSparseTensorEncoding(op.getTensor().getType());
    if (!enc)
      return failure();

    // If the user requests not to deallocate sparse tensors, simply erase
    // the operation.
    if (createDeallocs) {
      // Replace the sparse tensor deallocation with field deallocations.
      Location loc = op.getLoc();
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getTensor(),
          cast<RankedTensorType>(op.getTensor().getType()));
      for (auto input : desc.getMemRefFields())
        // Deallocate every buffer used to store the sparse tensor handler.
        rewriter.create<memref::DeallocOp>(loc, input);
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  const bool createDeallocs;
};

/// Sparse codegen rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prepare descriptor.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // Generate optional insertion finalization code.
    if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), desc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for the expand op.
9038a583bd5Sbixia1 class SparseExpandConverter : public OpConversionPattern<ExpandOp> { 9048a583bd5Sbixia1 public: 9058a583bd5Sbixia1 using OpConversionPattern::OpConversionPattern; 9068a583bd5Sbixia1 LogicalResult 907*9df63b26SMatthias Springer matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor, 9088a583bd5Sbixia1 ConversionPatternRewriter &rewriter) const override { 909191c43f6SPeiming Liu if (!getSparseTensorEncoding(op.getTensor().getType())) 910191c43f6SPeiming Liu return failure(); 9118a583bd5Sbixia1 Location loc = op->getLoc(); 912204234a6SMatthias Springer auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), 913204234a6SMatthias Springer op.getTensor().getType()); 914f708a549Swren romano const auto srcType = getSparseTensorType(op.getTensor()); 9158a583bd5Sbixia1 Type eltType = srcType.getElementType(); 9168a583bd5Sbixia1 Type boolType = rewriter.getIntegerType(1); 9178a583bd5Sbixia1 Type idxType = rewriter.getIndexType(); 9188a583bd5Sbixia1 // All initialization should be done on entry of the loop nest. 9198a583bd5Sbixia1 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); 920c780352dSPeiming Liu 9218a583bd5Sbixia1 // Determine the size for access expansion (always the innermost stored 922c780352dSPeiming Liu // level size). 923c780352dSPeiming Liu const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1); 9248a583bd5Sbixia1 // Generate a memref for `sz` elements of type `t`. 925f708a549Swren romano const auto genAlloc = [&](Type t) { 926f708a549Swren romano const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); 92783a50839SPeiming Liu return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz}); 9288a583bd5Sbixia1 }; 9293986c869SAart Bik // Allocate temporary buffers for values/filled-switch and added. 9308a583bd5Sbixia1 // We do not use stack buffers for this, since the expanded size may 9318a583bd5Sbixia1 // be rather large (as it envelops a single expanded dense dimension). 
9328a583bd5Sbixia1 Value values = genAlloc(eltType); 9338a583bd5Sbixia1 Value filled = genAlloc(boolType); 9343986c869SAart Bik Value added = genAlloc(idxType); 9358a583bd5Sbixia1 Value zero = constantZero(rewriter, loc, idxType); 9368a583bd5Sbixia1 // Reset the values/filled-switch to all-zero/false. Note that this 9378a583bd5Sbixia1 // introduces an O(N) operation into the computation, but this reset 9388a583bd5Sbixia1 // operation is amortized over the innermost loops for the access 9398a583bd5Sbixia1 // pattern expansion. As noted in the operation doc, we would like 9408a583bd5Sbixia1 // to amortize this setup cost even between kernels. 9418a583bd5Sbixia1 rewriter.create<linalg::FillOp>( 9428a583bd5Sbixia1 loc, ValueRange{constantZero(rewriter, loc, eltType)}, 9438a583bd5Sbixia1 ValueRange{values}); 9448a583bd5Sbixia1 rewriter.create<linalg::FillOp>( 9458a583bd5Sbixia1 loc, ValueRange{constantZero(rewriter, loc, boolType)}, 9468a583bd5Sbixia1 ValueRange{filled}); 94784cd51bbSwren romano // Replace expansion op with these buffers and initial coordinate. 9488a583bd5Sbixia1 assert(op.getNumResults() == 4); 9493986c869SAart Bik rewriter.replaceOp(op, {values, filled, added, zero}); 9503986c869SAart Bik return success(); 9513986c869SAart Bik } 9523986c869SAart Bik }; 9533986c869SAart Bik 9543986c869SAart Bik /// Sparse codegen rule for the compress operator. 
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                op.getTensor().getType());
    Value values = getSingleValue(adaptor.getValues());
    Value filled = getSingleValue(adaptor.getFilled());
    Value added = getSingleValue(adaptor.getAdded());
    Value count = getSingleValue(adaptor.getCount());
    const SparseTensorType dstType(desc.getRankedTensorType());
    Type eltType = dstType.getElementType();

    // If the innermost level is ordered, we need to sort the coordinates
    // in the "added" array prior to applying the compression.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      rewriter.create<SortOp>(
          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
    // While performing the insertions, we also need to reset the elements
    // of the values/filled-switch by only iterating over the set elements,
    // to ensure that the runtime complexity remains proportional to the
    // sparsity of the expanded access pattern.
    //
    // Generate
    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //      crd = added[i];
    //      value = values[crd];
    //      insert({lvlCoords, crd}, value);
    //      new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
    //      values[crd] = 0;
    //      filled[crd] = false;
    //      yield new_memrefs
    //    }
    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
    Value i = loop.getInductionVar();

    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    // Build the argument list for the insertion routine: all storage fields,
    // followed by the flattened level coordinates, the compressed coordinate,
    // and the value.
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    // Reset the consumed values/filled-switch entries inside the loop body.
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
    rewriter.create<scf::YieldOp>(loc, insertRet);

    rewriter.setInsertionPointAfter(loop);
    // Deallocate the buffers on exit of the full loop nest.
    // NOTE(review): the insertion point set after `loop` above is immediately
    // superseded by the one after `parent` — confirm the first call is needed.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
    return success();
  }
};

/// Sparse codegen rule for the insert operator.
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getDest());
    if (!stt.hasEncoding())
      return failure();
    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");

    Location loc = op.getLoc();
    auto desc =
        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
    // Arguments: all storage fields, then the flattened coordinates, then
    // the scalar value being inserted.
    TypeRange flatSpTensorTps = desc.getFields().getTypes();
    SmallVector<Value> params = llvm::to_vector(desc.getFields());
    SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
    params.append(flatIndices.begin(), flatIndices.end());
    params.push_back(getSingleValue(adaptor.getScalar()));
    SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {ret});
    return success();
  }
};

/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
  using OpAdaptor = typename ToPositionsOp::Adaptor;
  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested position access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getPosMemRef(lvl);
    auto size = desc.getPosMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for accessing the coordinates arrays.
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    // Only levels before the AoS COO region get the size-restricting view;
    // presumably getCrdMemRefOrView already handles the AoS case — confirm.
    if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
      auto size = desc.getCrdMemSize(rewriter, loc, lvl);
      mem = genSliceToSize(rewriter, loc, mem, size);
    }
    rewriter.replaceOp(op, mem);
    return success();
  }
};

/// Sparse codegen rule for accessing the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getAOSMemRef();
    auto size = desc.getCrdMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpAdaptor = typename ToValuesOp::Adaptor;
  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested values access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getValMemRef();
    auto size = desc.getValMemSize(rewriter, loc);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for the convert operator.
class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
    SparseTensorEncodingAttr encSrc =
        getSparseTensorEncoding(op.getSource().getType());
    // The output tensor can not be a slice and those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
    // Different encoding (except for different bitwidth) should be handled by
    // rewriting.
    // We need further rewrites if the input tensor is a slice too.
    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
        encSrc.isSlice()) {
      return failure();
    }

    Type retElemTp = op.getResult().getType().getElementType();
    Type srcElemTp = op.getSource().getType().getElementType();
    // Fold the trivial cases: identical element type and encoding simply
    // forward the already-converted source fields (1:N replacement).
    if (retElemTp == srcElemTp && encDst == encSrc) {
      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
      return success();
    }
    //
    // Do element-wise type conversion without using InsertOp.
    //
    // for each memref in srcTensor:
    //   dst = memref.alloc
    //   if srcMemRefType != dstMemRefType:
    //     for every dst[i] = cast(src[i])
    //   else:
    //     dst = memref.copy(src)
    Location loc = op.getLoc();
    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                                op.getSource().getType());
    SmallVector<Value> fields;
    foreachFieldAndTypeInSparseTensor(
        SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
        [&rewriter, &fields, srcDesc,
         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
              LevelType /*lt*/) -> bool {
          // Simply reuses the storage specifier as it is an SSA value.
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(srcDesc.getSpecifier());
          } else {
            // Allocates new memrefs
            Value srcMem = srcDesc.getMemRefField(fIdx);
            // TODO: We can instead use the actual memSize in specifier, that
            // would require a subViewOp to avoid overflow when copying
            // values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
            auto dstMem = rewriter.create<memref::AllocOp>(
                loc, cast<MemRefType>(fTp), sz);
            if (fTp != srcMem.getType()) {
              // Converts elements type.
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                    Value casted = genCast(builder, loc, v,
                                           dstMem.getType().getElementType());
                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
                  });
            } else {
              // TODO: We can even reuse the same memref for the new tensor,
              // but that requires a `ref-counting` based memory management
              // for shared memrefs between multiple sparse tensors.
              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
            }
            fields.push_back(dstMem);
          }
          // Returning true continues the traversal over remaining fields.
          return true;
        });

    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};

/// Sparse codegen rule for tensor.extract_slice on sparse tensors: the slice
/// shares every memref with the source; only the storage specifier (which
/// carries the offset/size/stride metadata) is freshly created.
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    // TODO: We should check these in ExtractSliceOp::verify.
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
                                                op.getSource().getType());

    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
    desc.setSpecifier(newSpec);

    // Fills in slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      Dimension dim = idx;

      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set dynamic value here. But it would
      // requires us to fill the hole when casting a static slice to dynamic
      // slice.
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
                             dim, offsetV);

      // FIXME: we need to distinguish level sizes and dimension size for slices
      // here. Maybe we should store slice level sizes in a different array
      // instead of reusing it.
      assert(srcEnc.isIdentity());
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
                             sizeV);
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
                             dim, strideV);
    }

    // NOTE: we can not generate tuples directly from descriptor here, as the
    // descriptor is holding the original type, yet we want the slice type
    // here (they shared every memref but with an updated specifier).
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    // FIXME: the nse value computed in this way might be wrong when there is
    // any "loose_compressed" level.
    // The number of stored entries is read off the value-memory size kept in
    // the storage specifier.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
    return success();
  }
};

/// Sparse codegen rule for the assemble operator: wraps the user-provided
/// level/value buffers into sparse-tensor storage fields and initializes a
/// storage specifier that records level sizes and memory sizes.
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> fields;

    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, LevelType lt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
              // Flattens the buffer to batchLvlRank.
              auto reassoc = getReassociationForFlattening(
                  mem.getType(), stt.getBatchLvlRank());
              mem = rewriter.create<memref::CastOp>(
                  loc, fType,
                  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
            } else {
              mem = rewriter.create<memref::CastOp>(loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });

    MutSparseTensorDescriptor desc(stt, fields);
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c0; // index to the last value in the position array
    Value memSize = c1; // memory size for current array

    Level trailCOOStart = stt.getAoSCOOStart();
    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
    // Sets up SparseTensorSpecifier.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AOS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in position array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (lt.isa<LevelFormat::Dense>()) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }
      if (lt.isa<LevelFormat::Batch>()) {
        // Skips batch levels as it is not linearized.
        // FIXME: this assumes that every batch has the same number of nse, need
        // to be generalized to handle varied-size batches.
        continue;
      }

      if (isWithPosLT(lt)) {
        assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
        // Loose-compressed levels keep two position entries per parent entry,
        // regular compressed levels keep memSize + 1 entries.
        if (isLooseCompressedLT(lt)) {
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedLT(lt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in position array is the memory size for next level.
        // FIXME: this assumes that every batch has the same number of nse, need
        // to be generalized to handle varied-size batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
      // FIXME: This seems to be unnecessarily complex, can we simplify it?
      // The trailing COO section stores trailCOORank coordinates per entry in
      // its single AoS buffer, hence the multiplied size.
      if (lvl == trailCOOStart) {
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for the disassemble operator: copies the in-use prefix
/// of each storage buffer into the user-provided output buffers and returns,
/// per buffer, both the data and its actual length.
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
                                   &retLen](FieldIndex fid,
                                            SparseTensorFieldKind fKind,
                                            Level lvl, LevelType lt) -> bool {
      // The storage specifier is metadata only; nothing to copy out.
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());

        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);

        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc =
            getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      // Copy only the in-use prefix (size sz) from source to destination.
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });

    // Converts MemRefs back to Tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
    return success();
  }
};

/// Sparse codegen rule for the new operator: reads a sparse tensor from file
/// directly into an AoS COO storage scheme (other destinations are handled by
/// rewriting, see the guard below).
struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
      return failure();

    // Implement as follows:
    //   %reader = @createCheckedSparseTensorReader(%filename)
    //   %nse = @getSparseTensorNSE(%reader)
    //   %coo = bufferization.alloc_tensor an ordered COO with
    //          dst dim ordering, size_hint = %nse
    //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
    //   %values = sparse_tensor.values(%coo)
    //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
    //   if (!%isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
    //   update storage specifier
    //   @delSparseTensorReader(%reader)
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);

    // Get the number of stored entries.
    const Type indexTp = rewriter.getIndexType();
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);

    // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
    SmallVector<Value> lvlSizesValues;
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);

    // Construct allocation for each field.
    Value sizeHint = nse;
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                      lvlSizesValues, fields);

    // Read the COO tensor data.
    MutSparseTensorDescriptor desc(dstTp, fields);
    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    const Type boolTp = rewriter.getIntegerType(1);
    const Type elemTp = dstTp.getElementType();
    const Type crdTp = dstTp.getCrdType();
    // The runtime entry point is suffixed by the overhead/primary types, e.g.
    // "getSparseTensorReaderReadToBuffers0F32".
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
                                          primaryTypeFunctionSuffix(elemTp)};
    Value isSorted =
        createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                       EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO tensor
    // data if the input elements aren't sorted yet.
    const Level lvlRank = dstTp.getLvlRank();
    if (dstTp.isOrderedLvl(lvlRank - 1)) {
      Value kFalse = constantI1(rewriter, loc, false);
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, kFalse);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                              rewriter.getIndexAttr(0),
                              SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PosMemRef0[1] = nse.
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
    rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);

    // Update storage specifier.
    Value coordinatesSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, lvlRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
                           coordinatesSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};

/// Sparse codegen rule for the has_runtime_library operator: this is the
/// library-free codegen path, so the answer is the constant false (i1 0).
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
16167276b643Sbixia1 void mlir::populateSparseTensorCodegenPatterns( 1617206fad0eSMatthias Springer const TypeConverter &typeConverter, RewritePatternSet &patterns, 1618c44d307cSPeiming Liu bool createSparseDeallocs, bool enableBufferInitialization) { 1619e8e8df4cSMatthias Springer patterns.add< 1620e8e8df4cSMatthias Springer SparseAssembleOpConverter, SparseDisassembleOpConverter, 1621c780352dSPeiming Liu SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter, 1622de560888SPeiming Liu SparseCastConverter, SparseExtractSliceConverter, 1623e8e8df4cSMatthias Springer SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter, 1624e8e8df4cSMatthias Springer SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter, 16256db397a8SPeiming Liu SparseSliceGetterOpConverter<ToSliceOffsetOp, 16266db397a8SPeiming Liu StorageSpecifierKind::DimOffset>, 16276db397a8SPeiming Liu SparseSliceGetterOpConverter<ToSliceStrideOp, 16286db397a8SPeiming Liu StorageSpecifierKind::DimStride>, 16296db397a8SPeiming Liu SparseToPositionsConverter, SparseToCoordinatesConverter, 16306db397a8SPeiming Liu SparseToCoordinatesBufferConverter, SparseToValuesConverter, 1631d3af6535SAart Bik SparseConvertConverter, SparseNewConverter, 1632e8e8df4cSMatthias Springer SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>( 1633e8e8df4cSMatthias Springer typeConverter, patterns.getContext()); 1634de560888SPeiming Liu patterns.add<SparseTensorDeallocConverter>( 1635c44d307cSPeiming Liu typeConverter, patterns.getContext(), createSparseDeallocs); 16363e4a8c2cSAart Bik patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>( 16373e4a8c2cSAart Bik typeConverter, patterns.getContext(), enableBufferInitialization); 163886b22d31SAart Bik } 1639