xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp (revision dc4cfdbb8f9f665c1699e6289b6edfbc8d1bb443)
1 //===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a light-weight runtime support library for
10 // manipulating sparse tensors from MLIR.  More specifically, it provides
11 // C-API wrappers so that MLIR-generated code can call into the C++ runtime
12 // support library.  The functionality provided in this library is meant
13 // to simplify benchmarking, testing, and debugging of MLIR code operating
14 // on sparse tensors.  However, the provided functionality is **not**
15 // part of core MLIR itself.
16 //
17 // The following memory-resident sparse storage schemes are supported:
18 //
19 // (a) A coordinate scheme for temporarily storing and lexicographically
20 //     sorting a sparse tensor by coordinate (SparseTensorCOO).
21 //
22 // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
23 //     per-dimension sparse/dense annotations together with a dimension
24 //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
25 //
26 // The following external formats are supported:
27 //
28 // (1) Matrix Market Exchange (MME): *.mtx
29 //     https://math.nist.gov/MatrixMarket/formats.html
30 //
31 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
32 //     http://frostt.io/tensors/file-formats.html
33 //
34 // Two public APIs are supported:
35 //
36 // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
37 //     tensors. These methods should be used exclusively by MLIR
38 //     compiler-generated code.
39 //
40 // (II) Methods that accept C-style data structures to interact with sparse
41 //      tensors. These methods can be used by any external runtime that wants
42 //      to interact with MLIR compiler-generated code.
43 //
44 // In both cases (I) and (II), the SparseTensorStorage format is externally
45 // only visible as an opaque pointer.
46 //
47 //===----------------------------------------------------------------------===//
48 
49 #include "mlir/ExecutionEngine/SparseTensorRuntime.h"
50 
51 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
52 
53 #include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
54 #include "mlir/ExecutionEngine/SparseTensor/COO.h"
55 #include "mlir/ExecutionEngine/SparseTensor/File.h"
56 #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
57 
58 #include <cstring>
59 #include <numeric>
60 
61 using namespace mlir::sparse_tensor;
62 
63 //===----------------------------------------------------------------------===//
64 //
65 // Utilities for manipulating `StridedMemRefType`.
66 //
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 
71 #define ASSERT_NO_STRIDE(MEMREF)                                               \
72   do {                                                                         \
73     assert((MEMREF) && "Memref is nullptr");                                   \
74     assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride");    \
75   } while (false)
76 
77 #define MEMREF_GET_USIZE(MEMREF)                                               \
78   detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])
79 
80 #define ASSERT_USIZE_EQ(MEMREF, SZ)                                            \
81   assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) &&                   \
82          "Memref size mismatch")
83 
84 #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
85 
86 /// Initializes the memref with the provided size and data pointer. This
87 /// is designed for functions which want to "return" a memref that aliases
88 /// into memory owned by some other object (e.g., `SparseTensorStorage`),
89 /// without doing any actual copying.  (The "return" is in scarequotes
90 /// because the `_mlir_ciface_` calling convention migrates any returned
91 /// memrefs into an out-parameter passed before all the other function
92 /// parameters.)
93 template <typename DataSizeT, typename T>
aliasIntoMemref(DataSizeT size,T * data,StridedMemRefType<T,1> & ref)94 static inline void aliasIntoMemref(DataSizeT size, T *data,
95                                    StridedMemRefType<T, 1> &ref) {
96   ref.basePtr = ref.data = data;
97   ref.offset = 0;
98   using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>;
99   ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
100   ref.strides[0] = 1;
101 }
102 
103 } // anonymous namespace
104 
105 extern "C" {
106 
107 //===----------------------------------------------------------------------===//
108 //
109 // Public functions which operate on MLIR buffers (memrefs) to interact
110 // with sparse tensors (which are only visible as opaque pointers externally).
111 //
112 //===----------------------------------------------------------------------===//
113 
114 #define CASE(p, c, v, P, C, V)                                                 \
115   if (posTp == (p) && crdTp == (c) && valTp == (v)) {                          \
116     switch (action) {                                                          \
117     case Action::kEmpty: {                                                     \
118       return SparseTensorStorage<P, C, V>::newEmpty(                           \
119           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
120     }                                                                          \
121     case Action::kFromReader: {                                                \
122       assert(ptr && "Received nullptr for SparseTensorReader object");         \
123       SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
124       return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
125           lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                     \
126     }                                                                          \
127     case Action::kPack: {                                                      \
128       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
129       intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
130       return SparseTensorStorage<P, C, V>::newFromBuffers(                     \
131           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
132           dimRank, buffers);                                                   \
133     }                                                                          \
134     case Action::kSortCOOInPlace: {                                            \
135       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
136       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
137       tensor.sortInPlace();                                                    \
138       return ptr;                                                              \
139     }                                                                          \
140     }                                                                          \
141     fprintf(stderr, "unknown action %d\n", static_cast<uint32_t>(action));     \
142     exit(1);                                                                   \
143   }
144 
145 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
146 
147 // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
148 // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
149 // that this file cannot get out of sync with its header.
150 static_assert(std::is_same<index_type, uint64_t>::value,
151               "Expected index_type == uint64_t");
152 
153 // The Swiss-army-knife for sparse tensor creation.
_mlir_ciface_newSparseTensor(StridedMemRefType<index_type,1> * dimSizesRef,StridedMemRefType<index_type,1> * lvlSizesRef,StridedMemRefType<LevelType,1> * lvlTypesRef,StridedMemRefType<index_type,1> * dim2lvlRef,StridedMemRefType<index_type,1> * lvl2dimRef,OverheadType posTp,OverheadType crdTp,PrimaryType valTp,Action action,void * ptr)154 void *_mlir_ciface_newSparseTensor( // NOLINT
155     StridedMemRefType<index_type, 1> *dimSizesRef,
156     StridedMemRefType<index_type, 1> *lvlSizesRef,
157     StridedMemRefType<LevelType, 1> *lvlTypesRef,
158     StridedMemRefType<index_type, 1> *dim2lvlRef,
159     StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
160     OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
161   ASSERT_NO_STRIDE(dimSizesRef);
162   ASSERT_NO_STRIDE(lvlSizesRef);
163   ASSERT_NO_STRIDE(lvlTypesRef);
164   ASSERT_NO_STRIDE(dim2lvlRef);
165   ASSERT_NO_STRIDE(lvl2dimRef);
166   const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
167   const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
168   ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
169   ASSERT_USIZE_EQ(dim2lvlRef, lvlRank);
170   ASSERT_USIZE_EQ(lvl2dimRef, dimRank);
171   const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
172   const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
173   const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
174   const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
175   const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
176 
177   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
178   // This is safe because of the static_assert above.
179   if (posTp == OverheadType::kIndex)
180     posTp = OverheadType::kU64;
181   if (crdTp == OverheadType::kIndex)
182     crdTp = OverheadType::kU64;
183 
184   // Double matrices with all combinations of overhead storage.
185   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
186        uint64_t, double);
187   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
188        uint32_t, double);
189   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
190        uint16_t, double);
191   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
192        uint8_t, double);
193   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
194        uint64_t, double);
195   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
196        uint32_t, double);
197   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
198        uint16_t, double);
199   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
200        uint8_t, double);
201   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
202        uint64_t, double);
203   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
204        uint32_t, double);
205   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
206        uint16_t, double);
207   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
208        uint8_t, double);
209   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
210        uint64_t, double);
211   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
212        uint32_t, double);
213   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
214        uint16_t, double);
215   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
216        uint8_t, double);
217 
218   // Float matrices with all combinations of overhead storage.
219   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
220        uint64_t, float);
221   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
222        uint32_t, float);
223   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
224        uint16_t, float);
225   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
226        uint8_t, float);
227   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
228        uint64_t, float);
229   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
230        uint32_t, float);
231   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
232        uint16_t, float);
233   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
234        uint8_t, float);
235   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
236        uint64_t, float);
237   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
238        uint32_t, float);
239   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
240        uint16_t, float);
241   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
242        uint8_t, float);
243   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
244        uint64_t, float);
245   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
246        uint32_t, float);
247   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
248        uint16_t, float);
249   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
250        uint8_t, float);
251 
252   // Two-byte floats with both overheads of the same type.
253   CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
254   CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
255   CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
256   CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
257   CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
258   CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
259   CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
260   CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
261 
262   // Integral matrices with both overheads of the same type.
263   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
264   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
265   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
266   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
267   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
268   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
269   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
270   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
271   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
272   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
273   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
274   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
275   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
276   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
277   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
278   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
279 
280   // Complex matrices with wide overhead.
281   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
282   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
283 
284   // Unsupported case (add above if needed).
285   fprintf(stderr, "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
286           static_cast<int>(posTp), static_cast<int>(crdTp),
287           static_cast<int>(valTp));
288   exit(1);
289 }
290 #undef CASE
291 #undef CASE_SECSAME
292 
293 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
294   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
295                                         void *tensor) {                        \
296     assert(ref &&tensor);                                                      \
297     std::vector<V> *v;                                                         \
298     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
299     assert(v);                                                                 \
300     aliasIntoMemref(v->size(), v->data(), *ref);                               \
301   }
302 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
303 #undef IMPL_SPARSEVALUES
304 
305 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
306   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
307                            index_type lvl) {                                   \
308     assert(ref &&tensor);                                                      \
309     std::vector<TYPE> *v;                                                      \
310     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl);              \
311     assert(v);                                                                 \
312     aliasIntoMemref(v->size(), v->data(), *ref);                               \
313   }
314 
315 #define IMPL_SPARSEPOSITIONS(PNAME, P)                                         \
316   IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)317 MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
318 #undef IMPL_SPARSEPOSITIONS
319 
320 #define IMPL_SPARSECOORDINATES(CNAME, C)                                       \
321   IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
322 MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
323 #undef IMPL_SPARSECOORDINATES
324 
325 #define IMPL_SPARSECOORDINATESBUFFER(CNAME, C)                                 \
326   IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
327 MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
328 #undef IMPL_SPARSECOORDINATESBUFFER
329 
330 #undef IMPL_GETOVERHEAD
331 
332 #define IMPL_LEXINSERT(VNAME, V)                                               \
333   void _mlir_ciface_lexInsert##VNAME(                                          \
334       void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
335       StridedMemRefType<V, 0> *vref) {                                         \
336     assert(t &&vref);                                                          \
337     auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
338     ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
339     index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
340     assert(lvlCoords);                                                         \
341     V *value = MEMREF_GET_PAYLOAD(vref);                                       \
342     tensor.lexInsert(lvlCoords, *value);                                       \
343   }
344 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
345 #undef IMPL_LEXINSERT
346 
347 #define IMPL_EXPINSERT(VNAME, V)                                               \
348   void _mlir_ciface_expInsert##VNAME(                                          \
349       void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
350       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
351       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
352     assert(t);                                                                 \
353     auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
354     ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
355     ASSERT_NO_STRIDE(vref);                                                    \
356     ASSERT_NO_STRIDE(fref);                                                    \
357     ASSERT_NO_STRIDE(aref);                                                    \
358     ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref));                             \
359     index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
360     V *values = MEMREF_GET_PAYLOAD(vref);                                      \
361     bool *filled = MEMREF_GET_PAYLOAD(fref);                                   \
362     index_type *added = MEMREF_GET_PAYLOAD(aref);                              \
363     uint64_t expsz = vref->sizes[0];                                           \
364     tensor.expInsert(lvlCoords, values, filled, added, count, expsz);          \
365   }
366 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
367 #undef IMPL_EXPINSERT
368 
369 void *_mlir_ciface_createCheckedSparseTensorReader(
370     char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
371     PrimaryType valTp) {
372   ASSERT_NO_STRIDE(dimShapeRef);
373   const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
374   const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
375   auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
376   return static_cast<void *>(reader);
377 }
378 
_mlir_ciface_getSparseTensorReaderDimSizes(StridedMemRefType<index_type,1> * out,void * p)379 void _mlir_ciface_getSparseTensorReaderDimSizes(
380     StridedMemRefType<index_type, 1> *out, void *p) {
381   assert(out && p);
382   SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
383   auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
384   aliasIntoMemref(reader.getRank(), dimSizes, *out);
385 }
386 
387 #define IMPL_GETNEXT(VNAME, V, CNAME, C)                                       \
388   bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME(          \
389       void *p, StridedMemRefType<index_type, 1> *dim2lvlRef,                   \
390       StridedMemRefType<index_type, 1> *lvl2dimRef,                            \
391       StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) {          \
392     assert(p);                                                                 \
393     auto &reader = *static_cast<SparseTensorReader *>(p);                      \
394     ASSERT_NO_STRIDE(dim2lvlRef);                                              \
395     ASSERT_NO_STRIDE(lvl2dimRef);                                              \
396     ASSERT_NO_STRIDE(cref);                                                    \
397     ASSERT_NO_STRIDE(vref);                                                    \
398     const uint64_t dimRank = reader.getRank();                                 \
399     const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef);                     \
400     const uint64_t cSize = MEMREF_GET_USIZE(cref);                             \
401     const uint64_t vSize = MEMREF_GET_USIZE(vref);                             \
402     ASSERT_USIZE_EQ(lvl2dimRef, dimRank);                                      \
403     assert(cSize >= lvlRank * reader.getNSE());                                \
404     assert(vSize >= reader.getNSE());                                          \
405     (void)dimRank;                                                             \
406     (void)cSize;                                                               \
407     (void)vSize;                                                               \
408     index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                      \
409     index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);                      \
410     C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref);                              \
411     V *values = MEMREF_GET_PAYLOAD(vref);                                      \
412     return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim,               \
413                                       lvlCoordinates, values);                 \
414   }
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)415 MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
416 #undef IMPL_GETNEXT
417 
418 void _mlir_ciface_outSparseTensorWriterMetaData(
419     void *p, index_type dimRank, index_type nse,
420     StridedMemRefType<index_type, 1> *dimSizesRef) {
421   assert(p);
422   ASSERT_NO_STRIDE(dimSizesRef);
423   assert(dimRank != 0);
424   index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
425   std::ostream &file = *static_cast<std::ostream *>(p);
426   file << dimRank << " " << nse << '\n';
427   for (index_type d = 0; d < dimRank - 1; d++)
428     file << dimSizes[d] << " ";
429   file << dimSizes[dimRank - 1] << '\n';
430 }
431 
432 #define IMPL_OUTNEXT(VNAME, V)                                                 \
433   void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
434       void *p, index_type dimRank,                                             \
435       StridedMemRefType<index_type, 1> *dimCoordsRef,                          \
436       StridedMemRefType<V, 0> *vref) {                                         \
437     assert(p &&vref);                                                          \
438     ASSERT_NO_STRIDE(dimCoordsRef);                                            \
439     const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef);            \
440     std::ostream &file = *static_cast<std::ostream *>(p);                      \
441     for (index_type d = 0; d < dimRank; d++)                                   \
442       file << (dimCoords[d] + 1) << " ";                                       \
443     V *value = MEMREF_GET_PAYLOAD(vref);                                       \
444     file << *value << '\n';                                                    \
445   }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)446 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
447 #undef IMPL_OUTNEXT
448 
449 //===----------------------------------------------------------------------===//
450 //
451 // Public functions which accept only C-style data structures to interact
452 // with sparse tensors (which are only visible as opaque pointers externally).
453 //
454 //===----------------------------------------------------------------------===//
455 
456 index_type sparseLvlSize(void *tensor, index_type l) {
457   return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
458 }
459 
sparseDimSize(void * tensor,index_type d)460 index_type sparseDimSize(void *tensor, index_type d) {
461   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
462 }
463 
endLexInsert(void * tensor)464 void endLexInsert(void *tensor) {
465   return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
466 }
467 
delSparseTensor(void * tensor)468 void delSparseTensor(void *tensor) {
469   delete static_cast<SparseTensorStorageBase *>(tensor);
470 }
471 
getTensorFilename(index_type id)472 char *getTensorFilename(index_type id) {
473   constexpr size_t bufSize = 80;
474   char var[bufSize];
475   snprintf(var, bufSize, "TENSOR%" PRIu64, id);
476   char *env = getenv(var);
477   if (!env) {
478     fprintf(stderr, "Environment variable %s is not set\n", var);
479     exit(1);
480   }
481   return env;
482 }
483 
getSparseTensorReaderNSE(void * p)484 index_type getSparseTensorReaderNSE(void *p) {
485   return static_cast<SparseTensorReader *>(p)->getNSE();
486 }
487 
delSparseTensorReader(void * p)488 void delSparseTensorReader(void *p) {
489   delete static_cast<SparseTensorReader *>(p);
490 }
491 
createSparseTensorWriter(char * filename)492 void *createSparseTensorWriter(char *filename) {
493   std::ostream *file =
494       (filename[0] == 0) ? &std::cout : new std::ofstream(filename);
495   *file << "# extended FROSTT format\n";
496   return static_cast<void *>(file);
497 }
498 
delSparseTensorWriter(void * p)499 void delSparseTensorWriter(void *p) {
500   std::ostream *file = static_cast<std::ostream *>(p);
501   file->flush();
502   assert(file->good());
503   if (file != &std::cout)
504     delete file;
505 }
506 
507 } // extern "C"
508 
509 #undef MEMREF_GET_PAYLOAD
510 #undef ASSERT_USIZE_EQ
511 #undef MEMREF_GET_USIZE
512 #undef ASSERT_NO_STRIDE
513 
514 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
515