//===- MapRef.h - A dim2lvl/lvl2dim map encoding ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A dim2lvl/lvl2dim map encoding class, with utility methods.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H

#include <cinttypes>

#include <cassert>
#include <vector>

namespace mlir {
namespace sparse_tensor {

/// Compactly encoded view of a sparse tensor type map, covering both the
/// dimension-to-level and the level-to-dimension direction.
///
/// Two kinds of map are handled:
///   (1) a permutation, and
///   (2) a map built from a restricted set of affine operations.
///
/// Translating coordinates is a plain table lookup for (1); for (2) every
/// coordinate is additionally probed for an affine form, which adds overhead.
///
/// The two encoding arrays are referenced, not copied or owned.
class MapRef final {
public:
  MapRef(uint64_t d, uint64_t l, const uint64_t *d2l, const uint64_t *l2d);

  //
  // Dimensions -> levels.
  //

  /// Translates the values in `in` into `lvlRank` values stored to `out`.
  ///
  /// For a permutation, `out[l]` is `in[dim2lvl[l]]`. Otherwise each level is
  /// probed in order: a floor form yields `in[i] / c`, a mod form yields
  /// `in[i] % c`, and anything else falls back to the plain table lookup.
  template <typename T>
  inline void pushforward(const T *in, T *out) const {
    if (isPermutation) {
      // Fast path: pure table lookup.
      for (uint64_t lvl = 0; lvl < lvlRank; ++lvl)
        out[lvl] = in[dim2lvl[lvl]];
      return;
    }
    // Filled in by isFloor/isMod whenever they report a match.
    uint64_t src, cst;
    for (uint64_t lvl = 0; lvl < lvlRank; ++lvl) {
      if (isFloor(lvl, src, cst)) {
        out[lvl] = in[src] / cst;
        continue;
      }
      if (isMod(lvl, src, cst)) {
        out[lvl] = in[src] % cst;
        continue;
      }
      out[lvl] = in[dim2lvl[lvl]];
    }
  }

  //
  // Levels -> dimensions.
  //

  /// Translates the values in `in` into `dimRank` values stored to `out`.
  ///
  /// For a permutation, `out[d]` is `in[lvl2dim[d]]`. Otherwise a dimension
  /// recognized as a mul form yields `in[i] + c * in[ii]`, and anything else
  /// falls back to the plain table lookup.
  template <typename T>
  inline void pushbackward(const T *in, T *out) const {
    if (isPermutation) {
      // Fast path: pure table lookup.
      for (uint64_t dim = 0; dim < dimRank; ++dim)
        out[dim] = in[lvl2dim[dim]];
      return;
    }
    // Filled in by isMul whenever it reports a match.
    uint64_t lhs, cst, rhs;
    for (uint64_t dim = 0; dim < dimRank; ++dim) {
      if (isMul(dim, lhs, cst, rhs)) {
        out[dim] = in[lhs] + cst * in[rhs];
        continue;
      }
      out[dim] = in[lvl2dim[dim]];
    }
  }

  /// Returns the stored dimension rank.
  uint64_t getDimRank() const { return dimRank; }

  /// Returns the stored level rank.
  uint64_t getLvlRank() const { return lvlRank; }

private:
  // Declared here, defined out of line.
  bool isPermutationMap() const;

  // Probes of the encoding used by the non-permutation paths above; defined
  // out of line. When one returns true, the callers above read the reference
  // arguments as the index/indices into `in` and the constant of the form.
  bool isFloor(uint64_t l, uint64_t &i, uint64_t &c) const;
  bool isMod(uint64_t l, uint64_t &i, uint64_t &c) const;
  bool isMul(uint64_t d, uint64_t &i, uint64_t &c, uint64_t &ii) const;

  const uint64_t dimRank;
  const uint64_t lvlRank;
  const uint64_t *const dim2lvl; // borrowed, not owned
  const uint64_t *const lvl2dim; // borrowed, not owned
  const bool isPermutation;      // selects the lookup-only fast path
};

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H