Lines Matching +full:non +full:- +full:batch
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
57 //===----------------------------------------------------------------------===//
59 //===----------------------------------------------------------------------===//
79 // batch levels, we therefore return a dynamic shape memref instead.
83 // batch dimension.
94 //===----------------------------------------------------------------------===//
96 //===----------------------------------------------------------------------===//
98 static constexpr Level kInvalidLevel = -1u;
99 static constexpr Level kInvalidFieldIndex = -1u;
112 // Per-level storage.
135 // Non COO levels.
160 // memref<[batch] x ? x pos> positions
162 // memref<[batch] x ? x crd> coordinates
164 // memref<[batch] x ? x eltType> values
170 Level lvl, LevelType lt) -> bool {
188 LevelType) -> bool {
198 LevelType) -> bool {
203 numFields -= 1; // the last field is StorageSpecifier
204 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
219 stride = lvlRank - cooStart;
224 LevelType lt) -> bool {
237 //===----------------------------------------------------------------------===//
239 //===----------------------------------------------------------------------===//
320 return emitError() << "expect non-negative value or ? for slice offset";
456 return LevelFormat::Batch;
514 // Handle non-permutation maps.
523 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
575 // Process the data from the parsed dictionary value into struct-like data.
594 unsigned keyWordIndex = it - keys.begin();
614 // NOTE: the old syntax requires an all-or-nothing approach to
616 // to convert null-DSA into default/nop DSA.
706 // Construct struct-like storage for attribute.
724 printer << ") -> (";
727 // Print remaining members only for non-default values.
745 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
748 printer << 's' << map.getNumSymbols() - 1;
756 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
759 printer << 'd' << map.getNumDims() - 1 << " : "
760 << dimSlices[map.getNumDims() - 1];
763 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
766 printer << 'd' << map.getNumDims() - 1;
772 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
777 auto lastIndex = map.getNumResults() - 1;
797 !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
808 return it->isa<LevelPropNonDefault::SoA>() ==
819 return emitError() << "Batch lvlType can only be leading levels.";
834 if (it != lvlTypes.end() - 1)
862 // Before we can check that the level-rank is consistent/coherent
863 // across all fields, we need to define it. The source-of-truth for
864 // the `getLvlRank` method is the length of the level-types array,
866 // use that same source-of-truth here.
869 return emitError() << "expected a non-empty array for lvlTypes";
875 << "level-rank mismatch between dimToLvl and lvlTypes: "
890 << "dimension-rank mismatch between dimSlices and dimToLvl: "
896 << "dimSlices expected dimension-rank to match level-rank: "
906 // level-rank is coherent across all the fields.
912 // need only check that the dimension-rank of the tensor agrees with
913 // the dimension-rank of the encoding.
916 return emitError() << "expected non-scalar sparse tensor";
919 << "dimension-rank mismatch between encoding and tensor shape: "
991 //===----------------------------------------------------------------------===//
993 //===----------------------------------------------------------------------===//
1007 return !isUnique || isUniqueLvl(lvlRank - 1);
1014 // A non-unique compressed level at beginning (unless this is
1019 // Followed by n-2 non-unique singleton levels.
1020 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1031 //===----------------------------------------------------------------------===//
1033 //===----------------------------------------------------------------------===//
1149 it->second = conOp.getValue();
1156 if (conOp.getValue() != it->second)
1177 return stt && !stt->isIdentity();
1180 return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1181 llvm::any_of(op->getResults(), hasNonIdentityMap);
1186 assert(enc.isPermutation() && "Non permutation map not supported");
1195 assert(enc.isPermutation() && "Non permutation map not supported");
1239 //===----------------------------------------------------------------------===//
1241 //===----------------------------------------------------------------------===//
1256 return op->emitError(
1266 return op->emitError("requested slice data on non-slice tensor");
1270 return op->emitError("missing level argument");
1274 return op->emitError("requested level is out of bounds");
1277 return op->emitError(
1302 return op->emitError("the sparse-tensor must have static shape");
1304 return op->emitError("the sparse-tensor must have an encoding attribute");
1312 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1314 return op->emitError("input/output trailing COO level-ranks don't match");
1321 return op->emitError("inconsistent number of fields between input/output");
1327 Level lvl, LevelType lt) -> bool {
1349 return op->emitError("input/output element-types don't match");
1593 // A -> B, B -> A ==> A
1743 return op->emitError() << regionName << " region must have exactly "
1749 return op->emitError() << regionName << " region argument " << (i + 1)
1755 return op->emitError() << regionName
1759 return op->emitError() << regionName << " region yield type mismatch";
1765 NamedAttrList attrs = (*this)->getAttrs();
1774 // non-empty region.
1806 // non-empty region.
1820 Block *parent = getOperation()->getBlock();
1822 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1828 (def->getBlock() == absentBlock || def->getBlock() == parent))
1862 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1943 // Starts with `dimRank`-many coordinates.
1957 bodyBlock->getArguments().slice(0, dimRank),
1958 bodyBlock->getArguments()[dimRank],
1959 bodyBlock->getArguments().drop_front(dimRank + 1));
1965 const auto args = getBody()->getArguments();
1967 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1980 auto yield = cast<YieldOp>(getBody()->getTerminator());
2054 // compile-time constants.
2061 const char *message) -> LogicalResult {
2082 //===----------------------------------------------------------------------===//
2084 //===----------------------------------------------------------------------===//
2096 /// or simply "$lo" if $hi - $lo = 1
2117 /// or simply "$lo" if $hi - $lo = 1
2130 /// or simply "$lo" if $hi - $lo = 1
2140 /// or simply "$lo" if $hi - $lo = 1
2159 parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2195 if (i != size - 1)
2251 // parse ": sparse_tensor.iter_space -> ret"
2319 // parse ": (sparse_tensor.iter_space, ...) -> ret"
2401 return emitOpError("must use last-level iterator to extract values. ");
2412 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2415 if (crd->getUsers().empty())
2416 toRemove.set(crd->getArgNumber());
2428 iterateOp.getBody()->eraseArguments(toRemove);
2443 I64BitSet set((1 << rank) - 1);
2460 // Starts with a list of user-provided loop arguments.
2462 bodyBlock->addArgument(v.getType(), v.getLoc());
2466 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2469 bodyBlock->addArgument(
2523 "mismatch in number of loop-carried values and defined values");
2526 return op.emitOpError("required out-of-bound coordinates");
2622 I64BitSet set((1 << rank) - 1);
2623 // Generates all-zero case bits (they only serve as placeholders), which are
2767 //===----------------------------------------------------------------------===//
2769 //===----------------------------------------------------------------------===//