Searched refs:TensorLoopId (Results 1 – 4 of 4) sorted by relevance
| /llvm-project/mlir/include/mlir/Dialect/SparseTensor/Utils/ |
| H A D | Merger.h | 44 using TensorLoopId = unsigned; variable 255 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { in makeTensorLoopId() 346 constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; } in tensor() 348 constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; } in loop() 358 constexpr bool isOutTensor(TensorLoopId b, LoopId i) const { in isOutTensor() 405 LevelType getLvlType(TensorLoopId b) const { in getLvlType() 420 std::optional<Level> getLvl(TensorLoopId b) const { in getLvl() 436 TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>; 452 for (const TensorLoopId b : bits.set_bits()) { in foreachTensorLoopId() 501 bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const { in isLvlWithNonTrivialIdxExp() [all …]
|
| /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/ |
| H A D | CodegenEnv.h | 78 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { in lat() 87 LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
|
| /llvm-project/mlir/lib/Dialect/SparseTensor/Utils/ |
| H A D | Merger.cpp | 296 const TensorLoopId b = makeTensorLoopId(t, i); in addLat() 474 const TensorLoopId be = simple.size(); in simplifyCond() 475 TensorLoopId offset = 0; // relative to the end in simplifyCond() 480 if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) { in simplifyCond() 508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) in latGT() 672 for (TensorLoopId b : bits.set_bits()) { in hasAnySparse() 681 for (TensorLoopId b : bits.set_bits()) in hasSparseIdxReduction() 921 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { in dumpBits()
|
| /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/ |
| H A D | Sparsification.cpp | 935 [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt, in genIf() 1015 li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl, in getAllTidLvlsInLatPoints()
|