xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp (revision 306f4c306a3aae6ce0d92452b2f8fb72cf1908b0)
//===- MapRef.cpp - A dim2lvl/lvl2dim map reference wrapper ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8d3af6535SAart Bik 
9d3af6535SAart Bik #include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
10*306f4c30SAart Bik #include "mlir/Dialect/SparseTensor/IR/Enums.h"
11d3af6535SAart Bik 
MapRef(uint64_t d,uint64_t l,const uint64_t * d2l,const uint64_t * l2d)12d3af6535SAart Bik mlir::sparse_tensor::MapRef::MapRef(uint64_t d, uint64_t l, const uint64_t *d2l,
13d3af6535SAart Bik                                     const uint64_t *l2d)
14d3af6535SAart Bik     : dimRank(d), lvlRank(l), dim2lvl(d2l), lvl2dim(l2d),
15d3af6535SAart Bik       isPermutation(isPermutationMap()) {
16d3af6535SAart Bik   if (isPermutation) {
17*306f4c30SAart Bik     for (uint64_t l = 0; l < lvlRank; l++)
18*306f4c30SAart Bik       assert(lvl2dim[dim2lvl[l]] == l);
19d3af6535SAart Bik   }
20d3af6535SAart Bik }
21d3af6535SAart Bik 
isPermutationMap() const22d3af6535SAart Bik bool mlir::sparse_tensor::MapRef::isPermutationMap() const {
23d3af6535SAart Bik   if (dimRank != lvlRank)
24d3af6535SAart Bik     return false;
25d3af6535SAart Bik   std::vector<bool> seen(dimRank, false);
26*306f4c30SAart Bik   for (uint64_t l = 0; l < lvlRank; l++) {
27*306f4c30SAart Bik     const uint64_t d = dim2lvl[l];
28*306f4c30SAart Bik     if (d >= dimRank || seen[d])
29d3af6535SAart Bik       return false;
30*306f4c30SAart Bik     seen[d] = true;
31d3af6535SAart Bik   }
32d3af6535SAart Bik   return true;
33d3af6535SAart Bik }
34*306f4c30SAart Bik 
isFloor(uint64_t l,uint64_t & i,uint64_t & c) const35*306f4c30SAart Bik bool mlir::sparse_tensor::MapRef::isFloor(uint64_t l, uint64_t &i,
36*306f4c30SAart Bik                                           uint64_t &c) const {
37*306f4c30SAart Bik   if (isEncodedFloor(dim2lvl[l])) {
38*306f4c30SAart Bik     i = decodeIndex(dim2lvl[l]);
39*306f4c30SAart Bik     c = decodeConst(dim2lvl[l]);
40*306f4c30SAart Bik     return true;
41*306f4c30SAart Bik   }
42*306f4c30SAart Bik   return false;
43*306f4c30SAart Bik }
44*306f4c30SAart Bik 
isMod(uint64_t l,uint64_t & i,uint64_t & c) const45*306f4c30SAart Bik bool mlir::sparse_tensor::MapRef::isMod(uint64_t l, uint64_t &i,
46*306f4c30SAart Bik                                         uint64_t &c) const {
47*306f4c30SAart Bik   if (isEncodedMod(dim2lvl[l])) {
48*306f4c30SAart Bik     i = decodeIndex(dim2lvl[l]);
49*306f4c30SAart Bik     c = decodeConst(dim2lvl[l]);
50*306f4c30SAart Bik     return true;
51*306f4c30SAart Bik   }
52*306f4c30SAart Bik   return false;
53*306f4c30SAart Bik }
54*306f4c30SAart Bik 
isMul(uint64_t d,uint64_t & i,uint64_t & c,uint64_t & ii) const55*306f4c30SAart Bik bool mlir::sparse_tensor::MapRef::isMul(uint64_t d, uint64_t &i, uint64_t &c,
56*306f4c30SAart Bik                                         uint64_t &ii) const {
57*306f4c30SAart Bik   if (isEncodedMul(lvl2dim[d])) {
58*306f4c30SAart Bik     i = decodeIndex(lvl2dim[d]);
59*306f4c30SAart Bik     c = decodeMulc(lvl2dim[d]);
60*306f4c30SAart Bik     ii = decodeMuli(lvl2dim[d]);
61*306f4c30SAart Bik     return true;
62*306f4c30SAart Bik   }
63*306f4c30SAart Bik   return false;
64*306f4c30SAart Bik }
65