//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for the sparse memory layout.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

namespace mlir {
namespace sparse_tensor {

///===----------------------------------------------------------------------===//
/// The sparse tensor storage scheme for a tensor is organized as a single
/// compound type with the following fields. Note that every memref with `?`
/// size actually behaves as a "vector", i.e. the stored size is the capacity
/// and the used size resides in the storage_specifier struct.
///
/// struct {
///   ; per-level l:
///   ;  if dense:
///        <nothing>
///   ;  if compressed:
///        memref<[batch] x ? x pos>  positions   ; positions for level l
///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
///   ;  if loose-compressed:
///        memref<[batch] x ? x pos>  positions   ; lo/hi pos pairs for level l
///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
///   ;  if singleton/2-out-of-4:
///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
///
///   memref<[batch] x ? x eltType> values        ; values
///
///   struct sparse_tensor.storage_specifier {
///     array<rank x int> lvlSizes    ; sizes/cardinalities for each level
///     // TODO: memSizes need to be expanded to array<[batch] x n x int> to
///     // support different sizes for different batches. At the moment, we
///     // assume that every batch occupies the same memory size.
///     array<n x int> memSizes       ; sizes/lengths for each data memref
///   }
/// };
///
/// In addition, for a "trailing COO region", defined as a compressed level
/// followed by one or more singleton levels, the default SOA storage that
/// is inherent to the TACO format is optimized into an AOS storage where
/// all coordinates of a stored element appear consecutively.  In such cases,
/// a special operation (sparse_tensor.coordinates_buffer) must be used to
/// access the AOS coordinates array. In the code below, the method
/// `getCOOStart` is used to find the start of the "trailing COO region".
///
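/// For instance (an illustrative sketch; `aosCrd` and `i` are hypothetical
/// names, not part of this header): with a rank-2 trailing COO region, the
/// two coordinates of the i-th stored element occupy consecutive slots of
/// that shared buffer:
///   rowCrd = aosCrd[2 * i + 0]
///   colCrd = aosCrd[2 * i + 1]
///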
/// If the sparse tensor is a slice (produced by a `tensor.extract_slice`
/// operation), instead of allocating a new sparse tensor for it, it reuses
/// the same set of MemRefs but attaches an additional set of slicing metadata
/// for the per-dimension slice offset and stride.
///
/// Examples.
///
/// #CSR storage of a 2-dim matrix yields
///  memref<?xindex>                           ; positions-1
///  memref<?xindex>                           ; coordinates-1
///  memref<?xf64>                             ; values
///  struct<(array<2 x i64>, array<3 x i64>)>  ; lvl0, lvl1, 3xsizes
///
/// #COO storage of a 2-dim matrix yields
///  memref<?xindex>                           ; positions-0, essentially [0,sz]
///  memref<?xindex>                           ; AOS coordinates storage
///  memref<?xf64>                             ; values
///  struct<(array<2 x i64>, array<3 x i64>)>  ; lvl0, lvl1, 3xsizes
///
/// Slice on #COO storage of a 2-dim matrix yields
///  ;; Inherited from the original sparse tensor
///  memref<?xindex>                           ; positions-0, essentially [0,sz]
///  memref<?xindex>                           ; AOS coordinates storage
///  memref<?xf64>                             ; values
///  struct<(array<2 x i64>, array<3 x i64>,   ; lvl0, lvl1, 3xsizes
///  ;; Extra slicing metadata
///          array<2 x i64>, array<2 x i64>)>  ; dim offset, dim stride.
///
///===----------------------------------------------------------------------===//

enum class SparseTensorFieldKind : uint32_t {
  StorageSpec = 0,
  PosMemRef = static_cast<uint32_t>(StorageSpecifierKind::PosMemSize),
  CrdMemRef = static_cast<uint32_t>(StorageSpecifierKind::CrdMemSize),
  ValMemRef = static_cast<uint32_t>(StorageSpecifierKind::ValMemSize)
};

inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
  assert(kind != SparseTensorFieldKind::StorageSpec);
  return static_cast<StorageSpecifierKind>(kind);
}

inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
  assert(kind != StorageSpecifierKind::LvlSize);
  return static_cast<SparseTensorFieldKind>(kind);
}
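
// For instance (an illustrative sketch): the two helpers above are inverse
// casts over the shared numeric values of the two enums.
//   assert(toSpecifierKind(SparseTensorFieldKind::PosMemRef) ==
//          StorageSpecifierKind::PosMemSize);
//   assert(toFieldKind(StorageSpecifierKind::CrdMemSize) ==
//          SparseTensorFieldKind::CrdMemRef);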

/// The type of field indices.  This alias helps make code more
/// self-documenting; unfortunately it is not type-checked, so it only
/// provides documentation rather than actually preventing mixups.
using FieldIndex = unsigned;

/// Provides methods to access the fields of a sparse tensor with the given
/// encoding.
class StorageLayout {
public:
  explicit StorageLayout(const SparseTensorType &stt)
      : StorageLayout(stt.getEncoding()) {}
  explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {
    assert(enc);
  }

  /// For each field that will be allocated for the given sparse tensor
  /// encoding, calls the callback with the corresponding field index,
  /// field kind, level, and level-type (the last two are only meaningful
  /// for level memrefs).  The field index always starts at zero and
  /// increments by one between callback invocations.  Ideally, all other
  /// methods should rely on this function to query the fields of a sparse
  /// tensor instead of relying on ad-hoc index computation.
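  ///
  /// For example (a sketch based on the #CSR layout in the file header),
  /// the callback would be invoked in the order
  ///   (0, PosMemRef, 1, compressed), (1, CrdMemRef, 1, compressed),
  ///   (2, ValMemRef, -, -), (3, StorageSpec, -, -),
  /// since the dense level 0 contributes no field and the storage
  /// specifier always comes last.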
  void foreachField(
      llvm::function_ref<bool(
          FieldIndex /*fieldIdx*/, SparseTensorFieldKind /*fieldKind*/,
          Level /*lvl (if applicable)*/, LevelType /*LT (if applicable)*/)>)
      const;

  /// Gets the field index for the requested field.
  FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
                                 std::optional<Level> lvl) const {
    return getFieldIndexAndStride(kind, lvl).first;
  }

  /// Gets the total number of fields for the given sparse tensor encoding.
  unsigned getNumFields() const;

  /// Gets the total number of data fields (coordinate arrays, position
  /// arrays, and a value array) for the given sparse tensor encoding.
  unsigned getNumDataFields() const;

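  /// Gets both the field index of the memref that stores the given kind of
  /// data for the given level, and the access stride within that memref.
  /// The stride is only non-trivial for coordinates in a "trailing COO
  /// region", where several levels share a single AOS coordinates buffer
  /// (see the layout description at the top of this file).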
  std::pair<FieldIndex, unsigned>
  getFieldIndexAndStride(SparseTensorFieldKind kind,
                         std::optional<Level> lvl) const;

private:
  const SparseTensorEncodingAttr enc;
};
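
// Example usage (an illustrative sketch, not part of the API): walking the
// fields of the #CSR layout from the file header, assuming `enc` is the
// corresponding SparseTensorEncodingAttr. Returning `true` from the callback
// continues the traversal.
//
//   StorageLayout layout(enc);
//   layout.foreachField([](FieldIndex idx, SparseTensorFieldKind kind,
//                          Level lvl, LevelType lt) {
//     // #CSR visits: (0, PosMemRef), (1, CrdMemRef), (2, ValMemRef),
//     // and finally (3, StorageSpec).
//     return true;
//   });
//   assert(layout.getNumFields() == 4);     // 3 data memrefs + specifier
//   assert(layout.getNumDataFields() == 3); // positions, coordinates, values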

//
// Wrapper functions to invoke StorageLayout-related methods.
//

inline unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
  return StorageLayout(enc).getNumFields();
}

inline unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
  return StorageLayout(enc).getNumDataFields();
}

inline void foreachFieldInSparseTensor(
    SparseTensorEncodingAttr enc,
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  return StorageLayout(enc).foreachField(callback);
}

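/// Variant of `StorageLayout::foreachField` that is driven by a
/// `SparseTensorType` and additionally passes the concrete type of each
/// field to the callback.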
void foreachFieldAndTypeInSparseTensor(
    SparseTensorType,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>);

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORSTORAGELAYOUT_H_