xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp (revision 306f4c306a3aae6ce0d92452b2f8fb72cf1908b0)
//===- MapRef.cpp - A dim2lvl/lvl2dim map reference wrapper ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
10 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
11 
MapRef(uint64_t d,uint64_t l,const uint64_t * d2l,const uint64_t * l2d)12 mlir::sparse_tensor::MapRef::MapRef(uint64_t d, uint64_t l, const uint64_t *d2l,
13                                     const uint64_t *l2d)
14     : dimRank(d), lvlRank(l), dim2lvl(d2l), lvl2dim(l2d),
15       isPermutation(isPermutationMap()) {
16   if (isPermutation) {
17     for (uint64_t l = 0; l < lvlRank; l++)
18       assert(lvl2dim[dim2lvl[l]] == l);
19   }
20 }
21 
isPermutationMap() const22 bool mlir::sparse_tensor::MapRef::isPermutationMap() const {
23   if (dimRank != lvlRank)
24     return false;
25   std::vector<bool> seen(dimRank, false);
26   for (uint64_t l = 0; l < lvlRank; l++) {
27     const uint64_t d = dim2lvl[l];
28     if (d >= dimRank || seen[d])
29       return false;
30     seen[d] = true;
31   }
32   return true;
33 }
34 
isFloor(uint64_t l,uint64_t & i,uint64_t & c) const35 bool mlir::sparse_tensor::MapRef::isFloor(uint64_t l, uint64_t &i,
36                                           uint64_t &c) const {
37   if (isEncodedFloor(dim2lvl[l])) {
38     i = decodeIndex(dim2lvl[l]);
39     c = decodeConst(dim2lvl[l]);
40     return true;
41   }
42   return false;
43 }
44 
isMod(uint64_t l,uint64_t & i,uint64_t & c) const45 bool mlir::sparse_tensor::MapRef::isMod(uint64_t l, uint64_t &i,
46                                         uint64_t &c) const {
47   if (isEncodedMod(dim2lvl[l])) {
48     i = decodeIndex(dim2lvl[l]);
49     c = decodeConst(dim2lvl[l]);
50     return true;
51   }
52   return false;
53 }
54 
isMul(uint64_t d,uint64_t & i,uint64_t & c,uint64_t & ii) const55 bool mlir::sparse_tensor::MapRef::isMul(uint64_t d, uint64_t &i, uint64_t &c,
56                                         uint64_t &ii) const {
57   if (isEncodedMul(lvl2dim[d])) {
58     i = decodeIndex(lvl2dim[d]);
59     c = decodeMulc(lvl2dim[d]);
60     ii = decodeMuli(lvl2dim[d]);
61     return true;
62   }
63   return false;
64 }
65