//===- SparseTensor.h - Sparse tensor dialect -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/bit.h"

//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
// these are not type-checked, so they only provide documentation rather
// than doing anything to prevent mixups.
//
//===----------------------------------------------------------------------===//

namespace mlir {
namespace sparse_tensor {

/// The type of dimension identifiers and dimension-ranks.
using Dimension = uint64_t;

/// The type of level identifiers and level-ranks.
using Level = uint64_t;

/// The type for individual components of a compile-time shape,
/// including the sentinel value `ShapedType::kDynamic` for dynamic sizes.
using Size = int64_t;

/// A simple structure that encodes a range of levels in a sparse tensor
/// that forms a COO segment.
struct COOSegment {
  std::pair<Level, Level> lvlRange; // [low, high)
  bool isSoA;

  bool isAoS() const { return !isSoA; }
  bool isSegmentStart(Level l) const { return l == lvlRange.first; }
  bool inSegment(Level l) const {
    return l >= lvlRange.first && l < lvlRange.second;
  }
};

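// For example (editorial sketch, not part of the upstream header): the
// trailing two levels of a 3-D COO tensor stored AoS form one segment.
//
//   COOSegment seg{/*lvlRange=*/{1, 3}, /*isSoA=*/false};
//   assert(seg.isAoS() && seg.isSegmentStart(1));
//   assert(seg.inSegment(2) && !seg.inSegment(3));
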
/// A simple wrapper encoding a bitset of (at most 64) levels, currently used
/// by the `sparse_tensor.iterate` operation to describe the set of levels on
/// which the coordinates should be loaded.
class I64BitSet {
  uint64_t storage = 0;

public:
  using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
  const_set_bits_iterator begin() const {
    return const_set_bits_iterator(*this);
  }
  const_set_bits_iterator end() const {
    return const_set_bits_iterator(*this, -1);
  }
  iterator_range<const_set_bits_iterator> bits() const {
    return make_range(begin(), end());
  }

  I64BitSet() = default;
  explicit I64BitSet(uint64_t bits) : storage(bits) {}
  operator uint64_t() const { return storage; }

  I64BitSet &set(unsigned i) {
    assert(i < 64);
    storage |= uint64_t(1) << i;
    return *this;
  }

  I64BitSet &operator|=(I64BitSet rhs) {
    storage |= static_cast<uint64_t>(rhs);
    return *this;
  }

  I64BitSet &lshift(unsigned offset) {
    storage = storage << offset;
    return *this;
  }

  // `*this` is a subset of `p` iff their union equals `p`.
  bool isSubSetOf(const I64BitSet p) const {
    I64BitSet tmp = *this;
    tmp |= p;
    return tmp == p;
  }

  // Needed by `llvm::const_set_bits_iterator_impl`.
  int find_first() const { return min(); }
  int find_next(unsigned prev) const {
    if (prev >= max() - 1)
      return -1;

    uint64_t b = storage >> (prev + 1);
    assert(b != 0);

    return llvm::countr_zero(b) + prev + 1;
  }

  bool operator[](unsigned i) const {
    assert(i < 64);
    return (storage & (uint64_t(1) << i)) != 0;
  }
  // Position of the lowest set bit, or -1 (wrapped to unsigned) when empty.
  unsigned min() const {
    unsigned m = llvm::countr_zero(storage);
    return m == 64 ? -1 : m;
  }
  // One past the position of the highest set bit (0 when empty).
  unsigned max() const { return 64 - llvm::countl_zero(storage); }
  unsigned count() const { return llvm::popcount(storage); }
  bool empty() const { return storage == 0; }
};

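// A minimal usage sketch (editorial example, not part of the upstream
// header): mark levels 0 and 5, then walk the set bits in increasing order.
//
//   I64BitSet lvls;
//   lvls.set(0).set(5);
//   assert(lvls.count() == 2 && lvls.min() == 0 && lvls.max() == 6);
//   for (unsigned l : lvls.bits())
//     ... // visits 0, then 5
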
} // namespace sparse_tensor
} // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen-defined classes
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.h.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.h.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc"

//===----------------------------------------------------------------------===//
// Additional convenience methods.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace sparse_tensor {

/// Convenience method to abbreviate casting `getType()`.
/// Returns a null type when the value is not a ranked tensor.
template <typename T>
inline RankedTensorType getRankedTensorType(T &&t) {
  assert(static_cast<bool>(std::forward<T>(t)) &&
         "getRankedTensorType got null argument");
  return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
}

/// Convenience method to abbreviate casting `getType()`.
/// Asserts when the value is not a memref.
template <typename T>
inline MemRefType getMemRefType(T &&t) {
  assert(static_cast<bool>(std::forward<T>(t)) &&
         "getMemRefType got null argument");
  return cast<MemRefType>(std::forward<T>(t).getType());
}

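// Usage sketch (editorial example; `v` is a hypothetical Value): note that
// `getRankedTensorType` uses `dyn_cast` and thus yields a null type for a
// value that is not a ranked tensor, whereas `getMemRefType` asserts.
//
//   if (RankedTensorType rtp = getRankedTensorType(v))
//     ... // rtp.getShape(), rtp.getElementType(), etc.
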
/// Convenience method to get a sparse encoding attribute from a type.
/// Returns a null attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);

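// A minimal usage sketch (editorial example; assumes the upstream
// `getLvlRank` accessor on the encoding): branch on whether a type
// carries a sparse encoding.
//
//   if (auto enc = getSparseTensorEncoding(tensor.getType())) {
//     Level lvlRank = enc.getLvlRank(); // only sparse types reach here
//     ...
//   }
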
/// Returns true iff the type range has any sparse tensor type.
inline bool hasAnySparseType(TypeRange types) {
  return llvm::any_of(types, [](Type type) {
    return getSparseTensorEncoding(type) != nullptr;
  });
}

/// Returns true iff the MLIR operation has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
  return hasAnySparseType(op->getOperands().getTypes());
}

/// Returns true iff the MLIR operation has any sparse result.
inline bool hasAnySparseResult(Operation *op) {
  return hasAnySparseType(op->getResults().getTypes());
}

/// Returns true iff the MLIR operation has any sparse operand or result.
inline bool hasAnySparseOperandOrResult(Operation *op) {
  return hasAnySparseOperand(op) || hasAnySparseResult(op);
}

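// Typical guard (editorial sketch): a sparsification rewrite can bail out
// early when an operation touches no sparse tensors at all.
//
//   if (!hasAnySparseOperandOrResult(op))
//     return failure();
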
/// Returns true iff the MLIR operation has any sparse tensor with a
/// non-identity dim2lvl map.
bool hasAnyNonIdentityOperandsOrResults(Operation *op);

//
// Inference.
//

/// Given the dimToLvl map, infers the lvlToDim map, or returns an
/// empty AffineMap when inference fails.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);

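// A minimal sketch (editorial example; `ctx` is a hypothetical
// MLIRContext*): for a pure permutation such as (i, j) -> (j, i),
// inference is expected to yield the inverse permutation, which for this
// particular map is the map itself.
//
//   AffineMap dimToLvl =
//       AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0}, ctx);
//   AffineMap lvlToDim = inferLvlToDim(dimToLvl, ctx);
//   assert(lvlToDim == inversePermutation(dimToLvl));
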
/// Returns the lvlToDim map for the given dimToLvl map, specific
/// to block-sparse cases.
/// Asserts on failure (so only use when known to succeed).
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);

/// Given the dimToLvl map, returns the block sizes in a vector.
/// For instance, a 2x3 block yields [2, 3]. An unblocked dimension `i`
/// contributes 0, while a trivially blocked one (`i floordiv 1`, `i mod 1`)
/// contributes 1. Hence the map below yields [0, 1]:
///   map = ( i, j ) ->
///         ( i : dense,
///           j floordiv 1 : compressed,
///           j mod 1      : dense
///         )
/// Only valid block sparsity is accepted.
SmallVector<unsigned> getBlockSize(AffineMap dimToLvl);

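// For a standard 2x3 BSR layout (editorial illustration), both dimensions
// are blocked, so both block sizes are reported:
//   map = ( i, j ) ->
//         ( i floordiv 2 : dense,
//           j floordiv 3 : compressed,
//           i mod 2      : dense,
//           j mod 3      : dense
//         )
// getBlockSize(map) returns [2, 3], and isBlockSparsity(map) holds.
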
/// Given the dimToLvl map, returns whether it encodes block sparsity.
bool isBlockSparsity(AffineMap dimToLvl);

//
// Reordering.
//

/// Convenience method to translate the given level to the corresponding
/// dimension.
/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
Dimension toDim(SparseTensorEncodingAttr enc, Level l);

/// Convenience method to translate the given dimension to the corresponding
/// level.
/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d);

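// A minimal sketch (editorial example; `enc` is a hypothetical encoding):
// when the dim2lvl map is the permutation (d0, d1) -> (d1, d0), level 0
// stores dimension 1 and vice versa, so the two queries invert each other.
//
//   assert(toDim(enc, 0) == 1 && toLvl(enc, 1) == 0);
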
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_