//===- SparseTensor.h - Sparse tensor dialect -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/bit.h"

//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
// these are not type-checked, so they only provide documentation rather
// than doing anything to prevent mixups.
//
//===----------------------------------------------------------------------===//

namespace mlir {
namespace sparse_tensor {

/// The type of dimension identifiers and dimension-ranks.
/// Shares its underlying type with `Level`, so the compiler will not catch
/// a dimension being passed where a level is expected.
using Dimension = uint64_t;

/// The type of level identifiers and level-ranks.
/// Shares its underlying type with `Dimension`, so the compiler will not
/// catch a level being passed where a dimension is expected.
using Level = uint64_t;

/// The type for individual components of a compile-time shape,
/// including the value `ShapedType::kDynamic` (for shapes).
/// Signed, unlike `Dimension` and `Level`.
using Size = int64_t;

/// A simple structure that encodes a range of levels in the sparse tensors
/// that form a COO segment.
50 struct COOSegment { 51 std::pair<Level, Level> lvlRange; // [low, high) 52 bool isSoA; 53 54 bool isAoS() const { return !isSoA; } 55 bool isSegmentStart(Level l) const { return l == lvlRange.first; } 56 bool inSegment(Level l) const { 57 return l >= lvlRange.first && l < lvlRange.second; 58 } 59 }; 60 61 /// A simple wrapper to encode a bitset of (at most 64) levels, currently used 62 /// by `sparse_tensor.iterate` operation for the set of levels on which the 63 /// coordinates should be loaded. 64 class I64BitSet { 65 uint64_t storage = 0; 66 67 public: 68 using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>; 69 const_set_bits_iterator begin() const { 70 return const_set_bits_iterator(*this); 71 } 72 const_set_bits_iterator end() const { 73 return const_set_bits_iterator(*this, -1); 74 } 75 iterator_range<const_set_bits_iterator> bits() const { 76 return make_range(begin(), end()); 77 } 78 79 I64BitSet() = default; 80 explicit I64BitSet(uint64_t bits) : storage(bits) {} 81 operator uint64_t() const { return storage; } 82 83 I64BitSet &set(unsigned i) { 84 assert(i < 64); 85 storage |= static_cast<uint64_t>(0x01u) << i; 86 return *this; 87 } 88 89 I64BitSet &operator|=(I64BitSet lhs) { 90 storage |= static_cast<uint64_t>(lhs); 91 return *this; 92 } 93 94 I64BitSet &lshift(unsigned offset) { 95 storage = storage << offset; 96 return *this; 97 } 98 99 bool isSubSetOf(const I64BitSet p) const { 100 I64BitSet tmp = *this; 101 tmp |= p; 102 return tmp == p; 103 } 104 105 // Needed by `llvm::const_set_bits_iterator_impl`. 
106 int find_first() const { return min(); } 107 int find_next(unsigned prev) const { 108 if (prev >= max() - 1) 109 return -1; 110 111 uint64_t b = storage >> (prev + static_cast<int64_t>(1)); 112 assert(b != 0); 113 114 return llvm::countr_zero(b) + prev + static_cast<int64_t>(1); 115 } 116 117 bool operator[](unsigned i) const { 118 assert(i < 64); 119 return (storage & (static_cast<int64_t>(1) << i)) != 0; 120 } 121 unsigned min() const { 122 unsigned m = llvm::countr_zero(storage); 123 return m == 64 ? -1 : m; 124 } 125 unsigned max() const { return 64 - llvm::countl_zero(storage); } 126 unsigned count() const { return llvm::popcount(storage); } 127 bool empty() const { return storage == 0; } 128 }; 129 130 } // namespace sparse_tensor 131 } // namespace mlir 132 133 //===----------------------------------------------------------------------===// 134 // TableGen-defined classes 135 //===----------------------------------------------------------------------===// 136 137 #define GET_ATTRDEF_CLASSES 138 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc" 139 140 #define GET_ATTRDEF_CLASSES 141 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.h.inc" 142 143 #define GET_TYPEDEF_CLASSES 144 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.h.inc" 145 146 #define GET_OP_CLASSES 147 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.h.inc" 148 149 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc" 150 151 //===----------------------------------------------------------------------===// 152 // Additional convenience methods. 153 //===----------------------------------------------------------------------===// 154 155 namespace mlir { 156 namespace sparse_tensor { 157 158 /// Convenience method to abbreviate casting `getType()`. 
159 template <typename T> 160 inline RankedTensorType getRankedTensorType(T &&t) { 161 assert(static_cast<bool>(std::forward<T>(t)) && 162 "getRankedTensorType got null argument"); 163 return dyn_cast<RankedTensorType>(std::forward<T>(t).getType()); 164 } 165 166 /// Convenience method to abbreviate casting `getType()`. 167 template <typename T> 168 inline MemRefType getMemRefType(T &&t) { 169 assert(static_cast<bool>(std::forward<T>(t)) && 170 "getMemRefType got null argument"); 171 return cast<MemRefType>(std::forward<T>(t).getType()); 172 } 173 174 /// Convenience method to get a sparse encoding attribute from a type. 175 /// Returns null-attribute for any type without an encoding. 176 SparseTensorEncodingAttr getSparseTensorEncoding(Type type); 177 178 /// Returns true iff the type range has any sparse tensor type. 179 inline bool hasAnySparseType(TypeRange types) { 180 return llvm::any_of(types, [](Type type) { 181 return getSparseTensorEncoding(type) != nullptr; 182 }); 183 } 184 185 /// Returns true iff MLIR operand has any sparse operand. 186 inline bool hasAnySparseOperand(Operation *op) { 187 return hasAnySparseType(op->getOperands().getTypes()); 188 } 189 190 /// Returns true iff MLIR operand has any sparse result. 191 inline bool hasAnySparseResult(Operation *op) { 192 return hasAnySparseType(op->getResults().getTypes()); 193 } 194 195 /// Returns true iff MLIR operand has any sparse operand or result. 196 inline bool hasAnySparseOperandOrResult(Operation *op) { 197 return hasAnySparseOperand(op) || hasAnySparseResult(op); 198 } 199 200 /// Returns true iff MLIR operation has any sparse tensor with non-identity 201 /// dim2lvl maps. 202 bool hasAnyNonIdentityOperandsOrResults(Operation *op); 203 204 // 205 // Inference. 206 // 207 208 /// Given the dimToLvl map, infers the lvlToDim map, or returns 209 /// empty Affine map when inference fails. 
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);

/// Returns the lvlToDim map for the given dimToLvl map specific
/// to the block sparse cases.
/// Asserts on failure (so only use when known to succeed).
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);

/// Given the dimToLvl map, returns the block sizes in a vector.
/// For instance, a 2x3 block will return [2, 3]. An unblocked dimension i
/// will return 0, and i floordiv 1, i mod 1 will return 1. Therefore,
/// the example below will return [0, 1].
///   map = ( i, j ) ->
///         ( i : dense,
///           j floordiv 1 : compressed,
///           j mod 1 : dense
///         )
/// Only valid block sparsity will be accepted.
SmallVector<unsigned> getBlockSize(AffineMap dimToLvl);

/// Given the dimToLvl map, returns whether it describes block sparsity.
bool isBlockSparsity(AffineMap dimToLvl);

//
// Reordering.
//

/// Convenience method to translate the given level to the corresponding
/// dimension.
/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
Dimension toDim(SparseTensorEncodingAttr enc, Level l);

/// Convenience method to translate the given dimension to the corresponding
/// level.
/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d);

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_