Lines Matching defs:stt

107   const SparseTensorType stt(desc.getRankedTensorType());
109 const Level lvlRank = stt.getLvlRank();
111 const auto lt = stt.getLvlType(lvl);
118 Value posZero = constantZero(builder, loc, stt.getPosType());
137 Value valZero = constantZero(builder, loc, stt.getElementType());
157 SparseTensorType stt, ValueRange dynSizes,
159 const Dimension dimRank = stt.getDimRank();
163 for (const Size sz : stt.getDimShape())
174 SparseTensorType stt, bool enableInit,
178 Level lvlRank = stt.getLvlRank();
184 if (stt.isAllDense()) {
190 if (stt.getAoSCOOStart() == 0) {
194 } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
209 stt,
210 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
217 field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
240 MutSparseTensorDescriptor desc(stt, fields);
241 Value posZero = constantZero(builder, loc, stt.getPosType());
242 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
244 const auto lt = stt.getLvlType(lvl);
277 const SparseTensorType stt(desc.getRankedTensorType());
278 const Level lvlRank = stt.getLvlRank();
329 const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
366 const SparseTensorType stt(desc.getRankedTensorType());
367 const Level lvlRank = stt.getLvlRank();
369 const auto lt = stt.getLvlType(lvl);
378 Type posType = stt.getPosType();
471 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
472 const Level lvlRank = stt.getLvlRank();
475 MutSparseTensorDescriptor desc(stt, fields);
482 const auto lt = stt.getLvlType(lvl);
515 if (!stt.isDenseLvl(lvlRank - 1))
527 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
531 const Level lvlRank = stt.getLvlRank();
533 std::string lvlType = toMLIRString(stt.getLvlType(l));
544 for (const auto sz : stt.getDimShape())
547 if (!stt.isIdentity())
548 nameOstream << stt.getDimToLvl() << "_";
549 nameOstream << stt.getElementType() << "_";
550 nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
1032 auto stt = getSparseTensorType(op.getDest());
1033 if (!stt.hasEncoding())
1035 assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1312 const auto stt = getSparseTensorType(op.getResult());
1317 stt,
1318 [&rewriter, &fields, &op, &stt,
1324 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1332 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1335 mem.getType(), stt.getBatchLvlRank());
1347 MutSparseTensorDescriptor desc(stt, fields);
1354 Level trailCOOStart = stt.getAoSCOOStart();
1355 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1357 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1358 assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1361 auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1369 LevelType lt = stt.getLvlType(lvl);
1397 SmallVector<Value> batched(stt.getBatchLvlRank(),
1441 SparseTensorType stt(desc.getRankedTensorType());
1467 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1469 getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());