//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor types and primitives to actual compiler
// visible buffers and actual compiler IR that implements these primitives on
// the selected sparse tensor storage schemes. This pass provides an alternative
// to the SparseTensorConversion pass, eliminating the dependence on a runtime
// support library (other than for file I/O), and providing many more
// opportunities for subsequent compiler optimization of the generated code.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/SparseTensorDescriptor.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Flatten the given value ranges into a single vector of values.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

/// Assert that the given value range contains a single value and return it.
static Value getSingleValue(ValueRange values) {
  assert(values.size() == 1 && "expected single value");
  return values.front();
}

/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return builder.create<memref::LoadOp>(loc, mem, idx);
}

/// Generates a store with proper `index` typing and proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}

/// Creates a straightforward counting for-loop.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  Type indexType = builder.getIndexType();
  if (!lower)
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  builder.setInsertionPointToStart(forOp.getBody());
  return forOp;
}

/// Creates a push back operation.
static void createPushback(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc,
                           SparseTensorFieldKind kind, std::optional<Level> lvl,
                           Value value, Value repeat = Value()) {
  Type etp = desc.getMemRefElementType(kind, lvl);
  Value field = desc.getMemRefField(kind, lvl);
  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);

  auto pushBackOp = builder.create<PushBackOp>(
      loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
      genCast(builder, loc, value, etp), repeat);

  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                         pushBackOp.getNewSize());
}

/// Generates code that allocates a sparse storage scheme for given rank.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
                               MutSparseTensorDescriptor desc, Level startLvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times. For loose
      // compression, we multiply linear by two in order to append both the
      // lo/hi positions.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      if (isLooseCompressedLT(lt)) {
        Value two = constantIndex(builder, loc, 2);
        linear = builder.create<arith::MulIOp>(loc, linear, two);
      }
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
      return;
    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
      return; // nothing to do
    }
    // Keep compounding the size, but nothing needs to be initialized
    // at this level. We will eventually reach a compressed level or
    // otherwise the values array for the from-here "all-dense" case.
    assert(isDenseLT(lt));
    Value size = desc.getLvlSize(builder, loc, lvl);
    linear = builder.create<arith::MulIOp>(loc, linear, size);
  }
  // Reached values array so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
}
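
// Illustrative sketch (not generated verbatim): for a CSR-like tensor with
// level types (dense, compressed), a call allocSchemeForRank(..., desc, 0)
// compounds the linear size over the dense level and then appends that many
// zero entries to the level-1 positions array, roughly:
//
//   %sz0    = <level-0 size>
//   %linear = arith.muli %c1, %sz0
//   <push_back %linear zeros onto positions[1]>
//
// Together with the single zero entry pushed at initialization time, this
// maintains the "linear + 1" length property described above.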

/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}

/// Creates the dim sizes array, filling in from dynamic sizes.
static void createDimSizes(OpBuilder &builder, Location loc,
                           SparseTensorType stt, ValueRange dynSizes,
                           /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
  const Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  unsigned i = 0;
  for (const Size sz : stt.getDimShape())
    dimSizesValues.push_back(ShapedType::isDynamic(sz)
                                 ? dynSizes[i++]
                                 : constantIndex(builder, loc, sz));
}

/// Creates allocation for each field in sparse tensor type. Note that
/// for all dynamic memrefs in the sparse tensor storage layout, the
/// memory size is really the capacity of the "vector", while the actual
/// size resides in the sizes array.
static void createAllocFields(OpBuilder &builder, Location loc,
                              SparseTensorType stt, bool enableInit,
                              Value sizeHint,
                              SmallVectorImpl<Value> &lvlSizesValues,
                              /*out*/ SmallVectorImpl<Value> &fields) {
  Level lvlRank = stt.getLvlRank();
  // Set up some heuristic sizes. We try to set the initial
  // size based on available information. Otherwise we just
  // initialize a few elements to start the reallocation chain.
  // TODO: refine this
  Value posHeuristic, crdHeuristic, valHeuristic;
  if (stt.isAllDense()) {
    valHeuristic = lvlSizesValues[0];
    for (Level lvl = 1; lvl < lvlRank; lvl++)
      valHeuristic =
          builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
  } else if (sizeHint) {
    if (stt.getAoSCOOStart() == 0) {
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      posHeuristic = builder.create<arith::AddIOp>(
          loc, sizeHint, constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
  }
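  // For example (illustrative only): for a rank-2 (dense, compressed) tensor
  // with a size hint of nnz stored entries, the heuristics above reserve a
  // capacity of roughly nnz + 1 positions, nnz coordinates, and nnz values;
  // without any hint, every buffer starts at a capacity of 16 and grows
  // through the push_back reallocation chain.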
  // Initializes all fields. An initial storage specifier and allocated
  // positions/coordinates/values memrefs (with heuristic capacity).
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
       enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                   Level /*lvl*/, LevelType /*lt*/) -> bool {
        assert(fields.size() == fIdx);
        Value field;
        switch (fKind) {
        case SparseTensorFieldKind::StorageSpec:
          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
          break;
        case SparseTensorFieldKind::PosMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   posHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::CrdMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   crdHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::ValMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   valHeuristic, enableInit);
          break;
        }
        assert(field);
        fields.push_back(field);
        // Returns true to continue the iteration.
        return true;
      });
  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
  // and gives all position fields an initial zero entry, so that it is
  // easier to maintain the "linear + 1" length property.
  MutSparseTensorDescriptor desc(stt, fields);
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}

/// Helper method that generates block specific to compressed case:
///
///   // given: parentPos = posCursor[lvl-1]
///   pstart = desc.positions[lvl][parentPos]
///   pstop = desc.positions[lvl][parentPos+1]
///   plast = pstop - 1
///   msz = desc.coordinates[lvl].size()
///   if (pstart < pstop) {
///     isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
///   } else { // first insertion
///     isPresent = false
///     desc.positions[lvl][parentPos] = msz
///   }
///   if (isPresent) { // coordinate is already present
///     pnext = plast
///   } else {
///     desc.coordinates[lvl].push_back(lvlCoords[lvl])
///     desc.positions[lvl][parentPos+1] = msz+1
///     pnext = msz
///     <prepare level lvl+1>
///   }
///   posCursor[lvl] = pnext
static Value genCompressed(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc, ValueRange lvlCoords,
                           Value /*unused*/, Value parentPos, Level lvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  SmallVector<Type> types;
  Type indexType = builder.getIndexType();
  Type boolType = builder.getIntegerType(1);
  unsigned crdFidx;
  unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
  types.pop_back();
  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  Value crd =
      genLoad(builder, loc, desc.getMemRefField(crdFidx),
              crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                         : plast);
  Value eq = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
      lvlCoords[lvl]);
  builder.create<scf::YieldOp>(loc, eq);
  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  if (lvl > 0)
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
  builder.setInsertionPointAfter(ifOp1);
  // If present construct. Note that for a non-unique dimension level, we
  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
  //
  // TODO: generate less temporary IR?
  //
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
  // If present (fields unaffected, update pnext to plast).
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());

  // FIXME: This does not look like a clean way, but it is probably the most
  // efficient way.
  desc.getFields().push_back(plast);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // If !present (changes fields, update pnext).
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 /*value=*/lvlCoords[lvl]);
  // Prepare the next level "as needed".
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);

  desc.getFields().push_back(msz);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // Update fields and return next pos.
  builder.setInsertionPointAfter(ifOp2);
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}

/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt)) {
      // Compressed dimensions need a position cleanup for all entries
      // that were not visited during the insertion pass.
      //
      // TODO: avoid cleanup and keep compressed scheme consistent at all
      // times?
      //
      if (lvl > 0) {
        Type posType = stt.getPosType();
        Value posMemRef = desc.getPosMemRef(lvl);
        Value hi = desc.getPosMemSize(builder, loc, lvl);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
        Value i = loop.getInductionVar();
        Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
        Value cond = builder.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, newv, posZero);
        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
                                                   cond, /*else*/ true);
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
        builder.create<scf::YieldOp>(loc, oldv);
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<scf::YieldOp>(loc, newv);
        builder.setInsertionPointAfter(ifOp);
        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
        builder.setInsertionPointAfter(loop);
      }
    } else {
      assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
             isNOutOfMLT(lt));
    }
  }
}

/// Generates a subview into the sizes.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto memTp = llvm::cast<MemRefType>(mem.getType());
  // For higher-dimensional memrefs, we assume that the innermost
  // dimension is always of the right size.
  // TODO: generate complex truncating view here too?
  if (memTp.getRank() > 1)
    return mem;
  // Truncate linear memrefs to given size.
  return builder
      .create<memref::SubViewOp>(
          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
          ArrayRef<int64_t>{0},                    // static offset
          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
          ArrayRef<int64_t>{1})                    // static stride
      .getResult();
}

/// Creates the reassociation array.
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
  // Create reassociation in the form:
  // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
  for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);

  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);

  return ret;
}
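
// For example (illustrative only), flattening a memref<2x3x4x5xf64> with
// batchLvls = 2 yields the reassociation {{0}, {1}, {2, 3}}, i.e. the two
// batch dimensions are kept and the trailing dimensions are collapsed into
// one.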

//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//

namespace {

/// Helper class to help lowering sparse_tensor.insert operation.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
public:
  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
                        bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};

  /// Generates code along an insertion path without the need for a "cursor".
  /// This current insertion strategy comes at the expense of some testing
  /// overhead for each insertion. The strategy will be optimized later for
  /// common insertion patterns. The current insertion strategy also assumes
  /// insertions occur in "a reasonable order" that enables building the
  /// storage scheme in an appending/inserting kind of fashion (i.e. no
  /// in-between insertions that need data movement). The implementation
  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
  ///
  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                       OpBuilder &builder, Location loc) {
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    const Level lvlRank = stt.getLvlRank();
    // Extract fields and coordinates from args.
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
    MutSparseTensorDescriptor desc(stt, fields);
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
    Value value = args.back();
    Value parentPos = constantZero(builder, loc, builder.getIndexType());
    // Generate code for every level.
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        // Create:
        //   if (!present) {
        //     coordinates[lvl].push_back(coords[lvl])
        //     <update positions and prepare level lvl + 1>
        //   }
        //   positions[lvl] = coordinates.size() - 1
        //   <insert @ positions[lvl] at next level lvl + 1>
        if (isLooseCompressedLT(lt)) {
          Value two = constantIndex(builder, loc, 2);
          parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
        }
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
      } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
        // Create:
        //   coordinates[lvl].push_back(coords[lvl])
        //   positions[lvl] = positions[lvl-1]
        //   <insert @ positions[lvl] at next level lvl + 1>
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
                       lvl, /*value=*/coords[lvl]);
      } else {
        assert(isDenseLT(lt));
        // Construct the new position as:
        //   positions[lvl] = size * positions[lvl-1] + coords[lvl]
        //   <insert @ positions[lvl] at next level lvl + 1>
        Value size = desc.getLvlSize(builder, loc, lvl);
        Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
        parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
      }
    }
    // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
    else
      genStore(builder, loc, value, desc.getValMemRef(), parentPos);
    return fields;
  }

  std::string getMangledFuncName() {
    // The mangled name of the function has this format:
    //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    SmallString<32> nameBuffer;
    llvm::raw_svector_ostream nameOstream(nameBuffer);
    nameOstream << kInsertFuncNamePrefix;
    const Level lvlRank = stt.getLvlRank();
    for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace/remove punctuations in level properties.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
    }
    // Static dim sizes are used in the generated code while dynamic sizes are
    // loaded from the dimSizes buffer. This is the reason for adding the shape
    // to the function name.
    for (const auto sz : stt.getDimShape())
      nameOstream << sz << "_";
    // Permutation information is also used in generating insertion.
    if (!stt.isIdentity())
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
    return nameOstream.str().str();
  }

private:
  TensorType rtp;
};

/// Sparse tensor storage conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create a return with the flattened value extracted from sparse tensors.
    rewriter.replaceOpWithNewOp<func::ReturnOp>(
        op, flattenValues(adaptor.getOperands()));
    return success();
  }
};

/// Sparse tensor storage conversion rule for calls.
class SparseCallConverter : public OpConversionPattern<func::CallOp> {
public:
  // The default CallOp converter can not handle 1:N type conversion.
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // In case of:
    //   sparse_tensor, f, sparse_tensor = call @foo(...)
    // ==>
    //   memref..., f, memref = call @foo(...) replace with
    //   cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
    SmallVector<Type> finalRetTy;
    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
      return failure();

    // (1) Generates new call with flattened return value.
    auto newCall = rewriter.create<func::CallOp>(
        loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
    // (2) Gather sparse tensor returns.
    SmallVector<SmallVector<Value>> packedResultVals;
    // Tracks the offset of current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
    unsigned retOffset = 0;
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
    SmallVector<Type> sparseFlat;
    for (auto ret : op.getResults()) {
      assert(retOffset < newCall.getNumResults());
      auto retType = ret.getType();
      if (failed(typeConverter->convertType(retType, sparseFlat)))
        llvm_unreachable("Failed to convert type in sparse tensor codegen");

      // Converted types can not be empty when the type conversion succeeds.
      assert(!sparseFlat.empty());
      if (sparseFlat.size() > 1) {
        auto flatSize = sparseFlat.size();
        packedResultVals.emplace_back();
        llvm::append_range(packedResultVals.back(),
                           newCall.getResults().slice(retOffset, flatSize));
        retOffset += flatSize;
      } else {
        // If this is a 1:1 conversion, no need for casting.
        packedResultVals.emplace_back();
        packedResultVals.back().push_back(newCall.getResult(retOffset));
        retOffset++;
      }
      sparseFlat.clear();
    }

    assert(packedResultVals.size() == op.getNumResults());
    rewriter.replaceOpWithMultiple(
        op, llvm::to_vector_of<ValueRange>(packedResultVals));
    return success();
  }
};

/// Sparse codegen rule for level accesses.
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::optional<int64_t> lvl = op.getConstantLvlIndex();
    RankedTensorType srcType = op.getSource().getType();
    if (!lvl || !getSparseTensorEncoding(srcType))
      return failure();

    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
    auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);

    rewriter.replaceOp(op, sz);
    return success();
  }
};

// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();

    SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
    SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());

    // Should have been verified.
    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
           dstStt.isCOOType() && srcStt.isCOOType());
    assert(dstStt.hasSameDimToLvl(srcStt));

    // We don't need a mutable descriptor here as we perform sorting in-place.
    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                             op.getInputCoo().getType());
    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
    auto crd = desc.getAOSMemRef();
    auto val = desc.getValMemRef();

    // Otherwise we need another data shuffle and a non-identity map.
    assert(dstStt.hasSameDimToLvl(srcStt));
    (void)dstStt; // to silence warning when assertion is disabled

    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

    rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
                            rewriter.getIndexAttr(0), op.getAlgorithm());

    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
    return success();
  }
};

template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;
  using typename OpConversionPattern<Op>::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());

    rewriter.replaceOp(op, v);
    return success();
  }
};

/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

/// Sparse codegen rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    Location loc = op.getLoc();
    // Deal with copy.
    if (op.getCopy()) {
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
      SmallVector<Value> fields;
      fields.reserve(desc.getNumFields());
      // Memcpy on memref fields.
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = rewriter.create<memref::DimOp>(loc, field, 0);
        auto copied =
            rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
        rewriter.create<memref::CopyOp>(loc, field, copied);
        fields.push_back(copied);
      }
      // Reuses specifier.
      fields.push_back(desc.getSpecifier());
      assert(fields.size() == desc.getNumFields());
      rewriter.replaceOpWithMultiple(op, {fields});
      return success();
    }

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // Level size equals dimension size since the lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType,
                   flattenValues(adaptor.getDynamicSizes()),
                   /*dimSizesValues=*/lvlSizesValues);

    // Construct allocation for each field.
    Value sizeHint = op.getSizeHint();
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  bool enableBufferInitialization;
};

/// Sparse codegen rule for the empty tensor operator.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }

    Location loc = op.getLoc();
    // Level size equals dimension size since the lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
                   /*dimSizesValues=*/lvlSizesValues);
    // Construct allocation for each field.
    Value sizeHint; // none
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  bool enableBufferInitialization;
};
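
// Note on the two allocation converters above: after conversion, a sparse
// tensor value is represented by the flattened list of its storage fields.
// As an illustrative (not verbatim) example, a CSR matrix with level types
// (dense, compressed) becomes roughly:
//
//   %positions   = memref.alloc(%posCap) : memref<?xindex>
//   %coordinates = memref.alloc(%crdCap) : memref<?xindex>
//   %values      = memref.alloc(%valCap) : memref<?xf64>
//   %specifier   = sparse_tensor.storage_specifier.init : ...
//
// where the memref element types actually follow the encoding's
// posWidth/crdWidth and the tensor's element type.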

/// Sparse codegen rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto enc = getSparseTensorEncoding(op.getTensor().getType());
    if (!enc)
      return failure();

    // If the user requests not to deallocate sparse tensors, simply erase the
    // operation.
    if (createDeallocs) {
      // Replace the sparse tensor deallocation with field deallocations.
      Location loc = op.getLoc();
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getTensor(),
          cast<RankedTensorType>(op.getTensor().getType()));
      for (auto input : desc.getMemRefFields())
        // Deallocate every buffer used to store the sparse tensor handle.
        rewriter.create<memref::DeallocOp>(loc, input);
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  const bool createDeallocs;
};

/// Sparse codegen rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prepare descriptor.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // Generate optional insertion finalization code.
    if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), desc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for the expand op.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorEncoding(op.getTensor().getType()))
      return failure();
    Location loc = op->getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    const auto srcType = getSparseTensorType(op.getTensor());
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());

    // Determine the size for access expansion (always the innermost stored
    // level size).
    const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
    // Generate a memref for `sz` elements of type `t`.
    const auto genAlloc = [&](Type t) {
      const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
    };
    // Allocate temporary buffers for values/filled-switch and added.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(eltType);
    Value filled = genAlloc(boolType);
    Value added = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, added, zero});
    return success();
  }
};

/// Sparse codegen rule for the compress operator.
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                op.getTensor().getType());
    Value values = getSingleValue(adaptor.getValues());
    Value filled = getSingleValue(adaptor.getFilled());
    Value added = getSingleValue(adaptor.getAdded());
    Value count = getSingleValue(adaptor.getCount());
    const SparseTensorType dstType(desc.getRankedTensorType());
    Type eltType = dstType.getElementType();

    // If the innermost level is ordered, we need to sort the coordinates
    // in the "added" array prior to applying the compression.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      rewriter.create<SortOp>(
          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
    // While performing the insertions, we also need to reset the elements
    // of the values/filled-switch by only iterating over the set elements,
    // to ensure that the runtime complexity remains proportional to the
    // sparsity of the expanded access pattern.
    //
    // Generate
    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //      crd = added[i];
    //      value = values[crd];
    //      insert({lvlCoords, crd}, value);
    //      new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
    //      values[crd] = 0;
    //      filled[crd] = false;
    //      yield new_memrefs
    //    }
    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
    Value i = loop.getInductionVar();

    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
    rewriter.create<scf::YieldOp>(loc, insertRet);

    rewriter.setInsertionPointAfter(loop);
    // Deallocate the buffers on exit of the full loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
    return success();
  }
};

/// Sparse codegen rule for the insert operator.
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto stt = getSparseTensorType(op.getDest());
    if (!stt.hasEncoding())
      return failure();
    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");

    Location loc = op.getLoc();
    auto desc =
        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
    TypeRange flatSpTensorTps = desc.getFields().getTypes();
    SmallVector<Value> params = llvm::to_vector(desc.getFields());
    SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
    params.append(flatIndices.begin(), flatIndices.end());
    params.push_back(getSingleValue(adaptor.getScalar()));
    SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {ret});
    return success();
  }
};

/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
  using OpAdaptor = typename ToPositionsOp::Adaptor;
  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested position access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getPosMemRef(lvl);
    auto size = desc.getPosMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for accessing the coordinates arrays.
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
      auto size = desc.getCrdMemSize(rewriter, loc, lvl);
      mem = genSliceToSize(rewriter, loc, mem, size);
    }
    rewriter.replaceOp(op, mem);
    return success();
  }
};

/// Sparse codegen rule for accessing the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getAOSMemRef();
    auto size = desc.getCrdMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};
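
// Note on the accessor conversions above and below: each one conceptually
// replaces the sparse_tensor query with the corresponding storage field,
// truncated to its actual size. A rough sketch of the emitted IR (names and
// exact syntax are illustrative only):
//
//   %sz   = sparse_tensor.storage_specifier.get %spec[...]
//   %view = memref.subview %field[0] [%sz] [1] : memref<?xT> to memref<?xT>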

/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpAdaptor = typename ToValuesOp::Adaptor;
  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested values access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getValMemRef();
    auto size = desc.getValMemSize(rewriter, loc);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for the convert operator.
class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
    SparseTensorEncodingAttr encSrc =
        getSparseTensorEncoding(op.getSource().getType());
    // The output tensor can not be a slice and those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
    // Different encodings (except for different bitwidths) should be handled
    // by rewriting.
    // We need further rewrites if the input tensor is a slice too.
    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
        encSrc.isSlice()) {
      return failure();
    }

    Type retElemTp = op.getResult().getType().getElementType();
    Type srcElemTp = op.getSource().getType().getElementType();
    // Fold the trivial cases.
    if (retElemTp == srcElemTp && encDst == encSrc) {
      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
      return success();
    }
    //
    // Do element-wise type conversion without using InsertOp.
    //
    // for each memref in srcTensor:
    //   dst = memref.alloc
    //   if srcMemRefType != dstMemRefType:
    //     for every dst[i] = cast(src[i])
    //   else:
    //     dst = memref.copy(src)
    Location loc = op.getLoc();
    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                                op.getSource().getType());
    SmallVector<Value> fields;
    foreachFieldAndTypeInSparseTensor(
        SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
        [&rewriter, &fields, srcDesc,
         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
              LevelType /*lt*/) -> bool {
          // Simply reuses the storage specifier as it is an SSA value.
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(srcDesc.getSpecifier());
          } else {
            // Allocates new memrefs.
            Value srcMem = srcDesc.getMemRefField(fIdx);
            // TODO: We can instead use the actual memSize in specifier, that
            // would require a subViewOp to avoid overflow when copying
            // values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
            auto dstMem = rewriter.create<memref::AllocOp>(
                loc, cast<MemRefType>(fTp), sz);
            if (fTp != srcMem.getType()) {
              // Converts elements type.
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                    Value casted = genCast(builder, loc, v,
                                           dstMem.getType().getElementType());
                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
                  });
            } else {
              // TODO: We can even reuse the same memref for the new tensor,
              // but that requires a `ref-counting` based memory management
              // for shared memrefs between multiple sparse tensors.
              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
            }
            fields.push_back(dstMem);
          }
          return true;
        });

    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};

class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    // TODO: We should check these in ExtractSliceOp::verify.
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
                                                op.getSource().getType());

    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
    desc.setSpecifier(newSpec);

    // Fills in slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      Dimension dim = idx;

      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set the dynamic values here. But it would
      // require us to fill the hole when casting a static slice to a dynamic
      // slice.
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
                             dim, offsetV);

      // FIXME: we need to distinguish level sizes and dimension sizes for
      // slices here. Maybe we should store slice level sizes in a different
      // array instead of reusing it.
      assert(srcEnc.isIdentity());
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
                             sizeV);
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
                             dim, strideV);
    }

    // NOTE: we cannot generate tuples directly from the descriptor here, as
    // the descriptor is holding the original type, yet we want the slice type
    // here (they share every memref but have an updated specifier).
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    // FIXME: the nse value computed in this way might be wrong when there is
    // any "loose_compressed" level.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
    return success();
  }
};

struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> fields;

    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, LevelType lt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
              // Flattens the buffer to batchLvlRank.
              auto reassoc = getReassociationForFlattening(
                  mem.getType(), stt.getBatchLvlRank());
              mem = rewriter.create<memref::CastOp>(
                  loc, fType,
                  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
            } else {
              mem = rewriter.create<memref::CastOp>(loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });

    MutSparseTensorDescriptor desc(stt, fields);
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c0; // index to the last value in the position array
    Value memSize = c1; // memory size for current array

    Level trailCOOStart = stt.getAoSCOOStart();
    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
    // Sets up SparseTensorSpecifier.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AOS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in position array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      // Sets up the memory size by reading the last value in position array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (lt.isa<LevelFormat::Dense>()) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }
      if (lt.isa<LevelFormat::Batch>()) {
        // Skips batch levels since they are not linearized.
        // FIXME: This assumes that every batch has the same number of stored
        // entries (nse); it needs to be generalized to handle varied-size
        // batches.
        continue;
      }

      if (isWithPosLT(lt)) {
        assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
        if (isLooseCompressedLT(lt)) {
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedLT(lt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        // FIXME: This assumes that every batch has the same number of stored
        // entries (nse); it needs to be generalized to handle varied-size
        // batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
      // FIXME: This seems to be unnecessarily complex, can we simplify it?
      if (lvl == trailCOOStart) {
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for the sparse_tensor.disassemble operator.
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
                                   &retLen](FieldIndex fid,
                                            SparseTensorFieldKind fKind,
                                            Level lvl, LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());

        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);
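        // Note that the number of entries actually stored in a position or
        // coordinate buffer is tracked in the storage specifier, not by the
        // (possibly over-allocated) memref itself, so the size is queried
        // from the specifier below.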
        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc =
            getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });

    // Converts MemRefs back to Tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
    return success();
  }
};

/// Sparse codegen rule for the sparse_tensor.new operator.
struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
      return failure();

    // Implement as follows:
    //   %reader = @createCheckedSparseTensorReader(%filename)
    //   %nse = @getSparseTensorNSE(%reader)
    //   %coo = bufferization.alloc_tensor an ordered COO with
    //          dst dim ordering, size_hint = %nse
    //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
    //   %values = sparse_tensor.values(%coo)
    //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
    //   if (!%isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
    //   update storage specifier
    //   @delSparseTensorReader(%reader)
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);

    // Get the number of stored entries.
    const Type indexTp = rewriter.getIndexType();
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);

    // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
    SmallVector<Value> lvlSizesValues;
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);

    // Construct allocation for each field.
    Value sizeHint = nse;
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                      lvlSizesValues, fields);
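
    // For an AoS COO destination starting at level 0, the allocated fields
    // are, roughly: one position buffer for level 0, a single AoS coordinate
    // buffer shared by all levels, the value buffer, and the storage
    // specifier (a sketch; the exact field list is defined by
    // foreachFieldAndTypeInSparseTensor).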
    // Read the COO tensor data.
    MutSparseTensorDescriptor desc(dstTp, fields);
    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    const Type boolTp = rewriter.getIntegerType(1);
    const Type elemTp = dstTp.getElementType();
    const Type crdTp = dstTp.getCrdType();
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
                                          primaryTypeFunctionSuffix(elemTp)};
    Value isSorted =
        createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                       EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO
    // tensor data if the input elements aren't sorted yet.
    const Level lvlRank = dstTp.getLvlRank();
    if (dstTp.isOrderedLvl(lvlRank - 1)) {
      Value kFalse = constantI1(rewriter, loc, false);
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, kFalse);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                              rewriter.getIndexAttr(0),
                              SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PosMemRef0[1] = nse.
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
    rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);

    // Update storage specifier.
    Value coordinatesSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, lvlRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
                           coordinatesSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};

/// Sparse codegen rule for the has_runtime_library operator (lowers to a
/// constant false).
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
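
// For reference, a minimal sketch of how a driver pass might use the
// populated set (names such as `SparseTensorTypeToBufferConverter` and the
// conversion-target setup are assumptions here, not taken from this file):
//
//   void runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     SparseTensorTypeToBufferConverter converter;
//     ConversionTarget target(*ctx);
//     RewritePatternSet patterns(ctx);
//     populateSparseTensorCodegenPatterns(
//         converter, patterns, /*createSparseDeallocs=*/false,
//         /*enableBufferInitialization=*/false);
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }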