Home
last modified time | relevance | path

Searched refs:TensorLoopId (Results 1 – 4 of 4) sorted by relevance

/llvm-project/mlir/include/mlir/Dialect/SparseTensor/Utils/
H A D Merger.h
44 using TensorLoopId = unsigned; variable
255 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { in makeTensorLoopId()
346 constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; } in tensor()
348 constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; } in loop()
358 constexpr bool isOutTensor(TensorLoopId b, LoopId i) const { in isOutTensor()
405 LevelType getLvlType(TensorLoopId b) const { in getLvlType()
420 std::optional<Level> getLvl(TensorLoopId b) const { in getLvl()
436 TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>;
452 for (const TensorLoopId b : bits.set_bits()) { in foreachTensorLoopId()
501 bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const { in isLvlWithNonTrivialIdxExp()
[all …]
/llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/
H A D CodegenEnv.h
78 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { in lat()
87 LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
/llvm-project/mlir/lib/Dialect/SparseTensor/Utils/
H A D Merger.cpp
296 const TensorLoopId b = makeTensorLoopId(t, i); in addLat()
474 const TensorLoopId be = simple.size(); in simplifyCond()
475 TensorLoopId offset = 0; // relative to the end in simplifyCond()
480 if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) { in simplifyCond()
508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) in latGT()
672 for (TensorLoopId b : bits.set_bits()) { in hasAnySparse()
681 for (TensorLoopId b : bits.set_bits()) in hasSparseIdxReduction()
921 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { in dumpBits()
/llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/
H A D Sparsification.cpp
935 [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt, in genIf()
1015 li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl, in getAllTidLvlsInLatPoints()