xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp (revision dc4cfdbb8f9f665c1699e6289b6edfbc8d1bb443)
162066929Swren romano //===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===//
262066929Swren romano //
362066929Swren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
462066929Swren romano // See https://llvm.org/LICENSE.txt for license information.
562066929Swren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
662066929Swren romano //
762066929Swren romano //===----------------------------------------------------------------------===//
862066929Swren romano //
962066929Swren romano // This file implements a light-weight runtime support library for
1062066929Swren romano // manipulating sparse tensors from MLIR.  More specifically, it provides
1162066929Swren romano // C-API wrappers so that MLIR-generated code can call into the C++ runtime
1262066929Swren romano // support library.  The functionality provided in this library is meant
1362066929Swren romano // to simplify benchmarking, testing, and debugging of MLIR code operating
1462066929Swren romano // on sparse tensors.  However, the provided functionality is **not**
1562066929Swren romano // part of core MLIR itself.
1662066929Swren romano //
1762066929Swren romano // The following memory-resident sparse storage schemes are supported:
1862066929Swren romano //
1962066929Swren romano // (a) A coordinate scheme for temporarily storing and lexicographically
2084cd51bbSwren romano //     sorting a sparse tensor by coordinate (SparseTensorCOO).
2162066929Swren romano //
2262066929Swren romano // (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
2462066929Swren romano //     ordering used by MLIR compiler-generated code (SparseTensorStorage).
2562066929Swren romano //
2662066929Swren romano // The following external formats are supported:
2762066929Swren romano //
2862066929Swren romano // (1) Matrix Market Exchange (MME): *.mtx
2962066929Swren romano //     https://math.nist.gov/MatrixMarket/formats.html
3062066929Swren romano //
3162066929Swren romano // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
3262066929Swren romano //     http://frostt.io/tensors/file-formats.html
3362066929Swren romano //
3462066929Swren romano // Two public APIs are supported:
3562066929Swren romano //
3662066929Swren romano // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
3762066929Swren romano //     tensors. These methods should be used exclusively by MLIR
3862066929Swren romano //     compiler-generated code.
3962066929Swren romano //
4062066929Swren romano // (II) Methods that accept C-style data structures to interact with sparse
4162066929Swren romano //      tensors. These methods can be used by any external runtime that wants
4262066929Swren romano //      to interact with MLIR compiler-generated code.
4362066929Swren romano //
4462066929Swren romano // In both cases (I) and (II), the SparseTensorStorage format is externally
4562066929Swren romano // only visible as an opaque pointer.
4662066929Swren romano //
4762066929Swren romano //===----------------------------------------------------------------------===//
4862066929Swren romano 
4962066929Swren romano #include "mlir/ExecutionEngine/SparseTensorRuntime.h"
5062066929Swren romano 
5162066929Swren romano #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
5262066929Swren romano 
53a3e48883Swren romano #include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
5462066929Swren romano #include "mlir/ExecutionEngine/SparseTensor/COO.h"
5562066929Swren romano #include "mlir/ExecutionEngine/SparseTensor/File.h"
5662066929Swren romano #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
5762066929Swren romano 
58c518745bSwren romano #include <cstring>
5962066929Swren romano #include <numeric>
6062066929Swren romano 
6162066929Swren romano using namespace mlir::sparse_tensor;
6262066929Swren romano 
6362066929Swren romano //===----------------------------------------------------------------------===//
6462066929Swren romano //
659bd5bfc6SAart Bik // Utilities for manipulating `StridedMemRefType`.
6662066929Swren romano //
6762066929Swren romano //===----------------------------------------------------------------------===//
6862066929Swren romano 
6962066929Swren romano namespace {
7062066929Swren romano 
71a3e48883Swren romano #define ASSERT_NO_STRIDE(MEMREF)                                               \
72a3e48883Swren romano   do {                                                                         \
73a3e48883Swren romano     assert((MEMREF) && "Memref is nullptr");                                   \
74a3e48883Swren romano     assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride");    \
75a3e48883Swren romano   } while (false)
76a3e48883Swren romano 
77a3e48883Swren romano #define MEMREF_GET_USIZE(MEMREF)                                               \
78a3e48883Swren romano   detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])
79a3e48883Swren romano 
80a3e48883Swren romano #define ASSERT_USIZE_EQ(MEMREF, SZ)                                            \
81a3e48883Swren romano   assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) &&                   \
82a3e48883Swren romano          "Memref size mismatch")
83a3e48883Swren romano 
84a3e48883Swren romano #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
85a3e48883Swren romano 
862af2e4dbSwren romano /// Initializes the memref with the provided size and data pointer. This
872af2e4dbSwren romano /// is designed for functions which want to "return" a memref that aliases
882af2e4dbSwren romano /// into memory owned by some other object (e.g., `SparseTensorStorage`),
892af2e4dbSwren romano /// without doing any actual copying.  (The "return" is in scarequotes
902af2e4dbSwren romano /// because the `_mlir_ciface_` calling convention migrates any returned
912af2e4dbSwren romano /// memrefs into an out-parameter passed before all the other function
922af2e4dbSwren romano /// parameters.)
932af2e4dbSwren romano template <typename DataSizeT, typename T>
aliasIntoMemref(DataSizeT size,T * data,StridedMemRefType<T,1> & ref)942af2e4dbSwren romano static inline void aliasIntoMemref(DataSizeT size, T *data,
952af2e4dbSwren romano                                    StridedMemRefType<T, 1> &ref) {
962af2e4dbSwren romano   ref.basePtr = ref.data = data;
97a3e48883Swren romano   ref.offset = 0;
985a98dd67SKazu Hirata   using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>;
992af2e4dbSwren romano   ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
100a3e48883Swren romano   ref.strides[0] = 1;
101a3e48883Swren romano }
102a3e48883Swren romano 
10362066929Swren romano } // anonymous namespace
10462066929Swren romano 
10562066929Swren romano extern "C" {
10662066929Swren romano 
10762066929Swren romano //===----------------------------------------------------------------------===//
10862066929Swren romano //
10962066929Swren romano // Public functions which operate on MLIR buffers (memrefs) to interact
11062066929Swren romano // with sparse tensors (which are only visible as opaque pointers externally).
11162066929Swren romano //
11262066929Swren romano //===----------------------------------------------------------------------===//
11362066929Swren romano 
/// One dispatch arm of `_mlir_ciface_newSparseTensor`. When the runtime
/// type triple (`posTp`, `crdTp`, `valTp`) equals (`p`, `c`, `v`), performs
/// `action` with the template instantiation `SparseTensorStorage<P, C, V>`
/// and returns. It relies on the enclosing function's locals (`posTp`,
/// `crdTp`, `valTp`, `action`, `ptr`, `dimRank`, `dimSizes`, `lvlRank`,
/// `lvlSizes`, `lvlTypes`, `dim2lvl`, `lvl2dim`). An action that is not
/// handled by the switch prints a diagnostic and exits with status 1.
/// (Only block comments may appear inside the macro body: a line comment
/// would swallow the continued line after the trailing backslash.)
#define CASE(p, c, v, P, C, V)                                                 \
  if (posTp == (p) && crdTp == (c) && valTp == (v)) {                          \
    switch (action) {                                                          \
    case Action::kEmpty: {                                                     \
      return SparseTensorStorage<P, C, V>::newEmpty(                           \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
    }                                                                          \
    case Action::kFromReader: {                                                \
      assert(ptr && "Received nullptr for SparseTensorReader object");         \
      SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
      return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                     \
    }                                                                          \
    case Action::kPack: {                                                      \
      /* Here `ptr` is the `intptr_t` array of buffers, not a storage. */      \
      assert(ptr && "Received nullptr for buffers");                           \
      intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
      return SparseTensorStorage<P, C, V>::newFromBuffers(                     \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
          dimRank, buffers);                                                   \
    }                                                                          \
    case Action::kSortCOOInPlace: {                                            \
      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
      tensor.sortInPlace();                                                    \
      return ptr;                                                              \
    }                                                                          \
    }                                                                          \
    /* `%d` expects an `int`; cast accordingly (as the diagnostic for an */    \
    /* unsupported type combination already does). */                          \
    fprintf(stderr, "unknown action %d\n", static_cast<int>(action));          \
    exit(1);                                                                   \
  }

/// Dispatch arm for the common case where positions and coordinates share
/// one overhead type.
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
14662066929Swren romano 
14762066929Swren romano // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
14862066929Swren romano // can safely rewrite kIndex to kU64.  We make this assertion to guarantee
14962066929Swren romano // that this file cannot get out of sync with its header.
15062066929Swren romano static_assert(std::is_same<index_type, uint64_t>::value,
15162066929Swren romano               "Expected index_type == uint64_t");
15262066929Swren romano 
153d3af6535SAart Bik // The Swiss-army-knife for sparse tensor creation.
_mlir_ciface_newSparseTensor(StridedMemRefType<index_type,1> * dimSizesRef,StridedMemRefType<index_type,1> * lvlSizesRef,StridedMemRefType<LevelType,1> * lvlTypesRef,StridedMemRefType<index_type,1> * dim2lvlRef,StridedMemRefType<index_type,1> * lvl2dimRef,OverheadType posTp,OverheadType crdTp,PrimaryType valTp,Action action,void * ptr)154c518745bSwren romano void *_mlir_ciface_newSparseTensor( // NOLINT
155c518745bSwren romano     StridedMemRefType<index_type, 1> *dimSizesRef,
156c518745bSwren romano     StridedMemRefType<index_type, 1> *lvlSizesRef,
1571944c4f7SAart Bik     StridedMemRefType<LevelType, 1> *lvlTypesRef,
158b7188d28SAart Bik     StridedMemRefType<index_type, 1> *dim2lvlRef,
159b7188d28SAart Bik     StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
16084cd51bbSwren romano     OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
161a3e48883Swren romano   ASSERT_NO_STRIDE(dimSizesRef);
162a3e48883Swren romano   ASSERT_NO_STRIDE(lvlSizesRef);
163a3e48883Swren romano   ASSERT_NO_STRIDE(lvlTypesRef);
164a3e48883Swren romano   ASSERT_NO_STRIDE(dim2lvlRef);
165d3af6535SAart Bik   ASSERT_NO_STRIDE(lvl2dimRef);
166a3e48883Swren romano   const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
167a3e48883Swren romano   const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
168a3e48883Swren romano   ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
169306f4c30SAart Bik   ASSERT_USIZE_EQ(dim2lvlRef, lvlRank);
170306f4c30SAart Bik   ASSERT_USIZE_EQ(lvl2dimRef, dimRank);
171a3e48883Swren romano   const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
172a3e48883Swren romano   const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
1731944c4f7SAart Bik   const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
174a3e48883Swren romano   const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
175d3af6535SAart Bik   const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
17662066929Swren romano 
17762066929Swren romano   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
17862066929Swren romano   // This is safe because of the static_assert above.
17984cd51bbSwren romano   if (posTp == OverheadType::kIndex)
18084cd51bbSwren romano     posTp = OverheadType::kU64;
18184cd51bbSwren romano   if (crdTp == OverheadType::kIndex)
18284cd51bbSwren romano     crdTp = OverheadType::kU64;
18362066929Swren romano 
18462066929Swren romano   // Double matrices with all combinations of overhead storage.
18562066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
18662066929Swren romano        uint64_t, double);
18762066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
18862066929Swren romano        uint32_t, double);
18962066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
19062066929Swren romano        uint16_t, double);
19162066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
19262066929Swren romano        uint8_t, double);
19362066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
19462066929Swren romano        uint64_t, double);
19562066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
19662066929Swren romano        uint32_t, double);
19762066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
19862066929Swren romano        uint16_t, double);
19962066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
20062066929Swren romano        uint8_t, double);
20162066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
20262066929Swren romano        uint64_t, double);
20362066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
20462066929Swren romano        uint32_t, double);
20562066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
20662066929Swren romano        uint16_t, double);
20762066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
20862066929Swren romano        uint8_t, double);
20962066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
21062066929Swren romano        uint64_t, double);
21162066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
21262066929Swren romano        uint32_t, double);
21362066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
21462066929Swren romano        uint16_t, double);
21562066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
21662066929Swren romano        uint8_t, double);
21762066929Swren romano 
21862066929Swren romano   // Float matrices with all combinations of overhead storage.
21962066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
22062066929Swren romano        uint64_t, float);
22162066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
22262066929Swren romano        uint32_t, float);
22362066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
22462066929Swren romano        uint16_t, float);
22562066929Swren romano   CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
22662066929Swren romano        uint8_t, float);
22762066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
22862066929Swren romano        uint64_t, float);
22962066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
23062066929Swren romano        uint32_t, float);
23162066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
23262066929Swren romano        uint16_t, float);
23362066929Swren romano   CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
23462066929Swren romano        uint8_t, float);
23562066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
23662066929Swren romano        uint64_t, float);
23762066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
23862066929Swren romano        uint32_t, float);
23962066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
24062066929Swren romano        uint16_t, float);
24162066929Swren romano   CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
24262066929Swren romano        uint8_t, float);
24362066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
24462066929Swren romano        uint64_t, float);
24562066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
24662066929Swren romano        uint32_t, float);
24762066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
24862066929Swren romano        uint16_t, float);
24962066929Swren romano   CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
25062066929Swren romano        uint8_t, float);
25162066929Swren romano 
25262066929Swren romano   // Two-byte floats with both overheads of the same type.
25362066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
25462066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
25562066929Swren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
25662066929Swren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
25762066929Swren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
25862066929Swren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
25962066929Swren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
26062066929Swren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
26162066929Swren romano 
26262066929Swren romano   // Integral matrices with both overheads of the same type.
26362066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
26462066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
26562066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
26662066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
26762066929Swren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
26862066929Swren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
26962066929Swren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
27062066929Swren romano   CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
27162066929Swren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
27262066929Swren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
27362066929Swren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
27462066929Swren romano   CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
27562066929Swren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
27662066929Swren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
27762066929Swren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
27862066929Swren romano   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
27962066929Swren romano 
28062066929Swren romano   // Complex matrices with wide overhead.
28162066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
28262066929Swren romano   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
28362066929Swren romano 
28462066929Swren romano   // Unsupported case (add above if needed).
2851c2456d6SAart Bik   fprintf(stderr, "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
28684cd51bbSwren romano           static_cast<int>(posTp), static_cast<int>(crdTp),
28762066929Swren romano           static_cast<int>(valTp));
2881c2456d6SAart Bik   exit(1);
28962066929Swren romano }
29062066929Swren romano #undef CASE
29162066929Swren romano #undef CASE_SECSAME
29262066929Swren romano 
29362066929Swren romano #define IMPL_SPARSEVALUES(VNAME, V)                                            \
29462066929Swren romano   void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
29562066929Swren romano                                         void *tensor) {                        \
29662066929Swren romano     assert(ref &&tensor);                                                      \
29762066929Swren romano     std::vector<V> *v;                                                         \
29862066929Swren romano     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
299a3e48883Swren romano     assert(v);                                                                 \
3002af2e4dbSwren romano     aliasIntoMemref(v->size(), v->data(), *ref);                               \
30162066929Swren romano   }
30262066929Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
30362066929Swren romano #undef IMPL_SPARSEVALUES
30462066929Swren romano 
30562066929Swren romano #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
30662066929Swren romano   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
30784cd51bbSwren romano                            index_type lvl) {                                   \
30862066929Swren romano     assert(ref &&tensor);                                                      \
30962066929Swren romano     std::vector<TYPE> *v;                                                      \
31084cd51bbSwren romano     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl);              \
311a3e48883Swren romano     assert(v);                                                                 \
3122af2e4dbSwren romano     aliasIntoMemref(v->size(), v->data(), *ref);                               \
31362066929Swren romano   }
314*dc4cfdbbSAart Bik 
31584cd51bbSwren romano #define IMPL_SPARSEPOSITIONS(PNAME, P)                                         \
31684cd51bbSwren romano   IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)31784cd51bbSwren romano MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
31884cd51bbSwren romano #undef IMPL_SPARSEPOSITIONS
31962066929Swren romano 
32084cd51bbSwren romano #define IMPL_SPARSECOORDINATES(CNAME, C)                                       \
32184cd51bbSwren romano   IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
32284cd51bbSwren romano MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
32384cd51bbSwren romano #undef IMPL_SPARSECOORDINATES
324*dc4cfdbbSAart Bik 
325*dc4cfdbbSAart Bik #define IMPL_SPARSECOORDINATESBUFFER(CNAME, C)                                 \
326*dc4cfdbbSAart Bik   IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
327*dc4cfdbbSAart Bik MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
328*dc4cfdbbSAart Bik #undef IMPL_SPARSECOORDINATESBUFFER
329*dc4cfdbbSAart Bik 
33062066929Swren romano #undef IMPL_GETOVERHEAD
33162066929Swren romano 
33262066929Swren romano #define IMPL_LEXINSERT(VNAME, V)                                               \
33384cd51bbSwren romano   void _mlir_ciface_lexInsert##VNAME(                                          \
33484cd51bbSwren romano       void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
33562066929Swren romano       StridedMemRefType<V, 0> *vref) {                                         \
33684cd51bbSwren romano     assert(t &&vref);                                                          \
33784cd51bbSwren romano     auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
33884cd51bbSwren romano     ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
33984cd51bbSwren romano     index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
34084cd51bbSwren romano     assert(lvlCoords);                                                         \
341a3e48883Swren romano     V *value = MEMREF_GET_PAYLOAD(vref);                                       \
34284cd51bbSwren romano     tensor.lexInsert(lvlCoords, *value);                                       \
34362066929Swren romano   }
34462066929Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
34562066929Swren romano #undef IMPL_LEXINSERT
34662066929Swren romano 
34762066929Swren romano #define IMPL_EXPINSERT(VNAME, V)                                               \
34862066929Swren romano   void _mlir_ciface_expInsert##VNAME(                                          \
34984cd51bbSwren romano       void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
35062066929Swren romano       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
35162066929Swren romano       StridedMemRefType<index_type, 1> *aref, index_type count) {              \
35284cd51bbSwren romano     assert(t);                                                                 \
35384cd51bbSwren romano     auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
35484cd51bbSwren romano     ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
355a3e48883Swren romano     ASSERT_NO_STRIDE(vref);                                                    \
356a3e48883Swren romano     ASSERT_NO_STRIDE(fref);                                                    \
357a3e48883Swren romano     ASSERT_NO_STRIDE(aref);                                                    \
358a3e48883Swren romano     ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref));                             \
35984cd51bbSwren romano     index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
360a3e48883Swren romano     V *values = MEMREF_GET_PAYLOAD(vref);                                      \
361a3e48883Swren romano     bool *filled = MEMREF_GET_PAYLOAD(fref);                                   \
362a3e48883Swren romano     index_type *added = MEMREF_GET_PAYLOAD(aref);                              \
363ab6334ddSAart Bik     uint64_t expsz = vref->sizes[0];                                           \
364ab6334ddSAart Bik     tensor.expInsert(lvlCoords, values, filled, added, count, expsz);          \
36562066929Swren romano   }
36662066929Swren romano MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
36762066929Swren romano #undef IMPL_EXPINSERT
36862066929Swren romano 
3692af2e4dbSwren romano void *_mlir_ciface_createCheckedSparseTensorReader(
3702af2e4dbSwren romano     char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
3712af2e4dbSwren romano     PrimaryType valTp) {
3722af2e4dbSwren romano   ASSERT_NO_STRIDE(dimShapeRef);
3732af2e4dbSwren romano   const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
3742af2e4dbSwren romano   const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
3752af2e4dbSwren romano   auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
3762af2e4dbSwren romano   return static_cast<void *>(reader);
3772af2e4dbSwren romano }
3782af2e4dbSwren romano 
_mlir_ciface_getSparseTensorReaderDimSizes(StridedMemRefType<index_type,1> * out,void * p)3792af2e4dbSwren romano void _mlir_ciface_getSparseTensorReaderDimSizes(
3802af2e4dbSwren romano     StridedMemRefType<index_type, 1> *out, void *p) {
3812af2e4dbSwren romano   assert(out && p);
3822af2e4dbSwren romano   SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
3832af2e4dbSwren romano   auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
3842af2e4dbSwren romano   aliasIntoMemref(reader.getRank(), dimSizes, *out);
385b32831f4Swren romano }
386b32831f4Swren romano 
38727ea470fSbixia1 #define IMPL_GETNEXT(VNAME, V, CNAME, C)                                       \
388b86d3cbcSAart Bik   bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME(          \
38927ea470fSbixia1       void *p, StridedMemRefType<index_type, 1> *dim2lvlRef,                   \
390d3af6535SAart Bik       StridedMemRefType<index_type, 1> *lvl2dimRef,                            \
39127ea470fSbixia1       StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) {          \
39227ea470fSbixia1     assert(p);                                                                 \
39327ea470fSbixia1     auto &reader = *static_cast<SparseTensorReader *>(p);                      \
394d3af6535SAart Bik     ASSERT_NO_STRIDE(dim2lvlRef);                                              \
395d3af6535SAart Bik     ASSERT_NO_STRIDE(lvl2dimRef);                                              \
39627ea470fSbixia1     ASSERT_NO_STRIDE(cref);                                                    \
39727ea470fSbixia1     ASSERT_NO_STRIDE(vref);                                                    \
398d3af6535SAart Bik     const uint64_t dimRank = reader.getRank();                                 \
399306f4c30SAart Bik     const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef);                     \
40027ea470fSbixia1     const uint64_t cSize = MEMREF_GET_USIZE(cref);                             \
40127ea470fSbixia1     const uint64_t vSize = MEMREF_GET_USIZE(vref);                             \
402306f4c30SAart Bik     ASSERT_USIZE_EQ(lvl2dimRef, dimRank);                                      \
403f8ce460eSAart Bik     assert(cSize >= lvlRank * reader.getNSE());                                \
404f8ce460eSAart Bik     assert(vSize >= reader.getNSE());                                          \
405d3af6535SAart Bik     (void)dimRank;                                                             \
40627ea470fSbixia1     (void)cSize;                                                               \
40727ea470fSbixia1     (void)vSize;                                                               \
408d3af6535SAart Bik     index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                      \
409d3af6535SAart Bik     index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);                      \
41027ea470fSbixia1     C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref);                              \
41127ea470fSbixia1     V *values = MEMREF_GET_PAYLOAD(vref);                                      \
412d3af6535SAart Bik     return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim,               \
413d3af6535SAart Bik                                       lvlCoordinates, values);                 \
41427ea470fSbixia1   }
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)41527ea470fSbixia1 MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
41627ea470fSbixia1 #undef IMPL_GETNEXT
41727ea470fSbixia1 
418b32831f4Swren romano void _mlir_ciface_outSparseTensorWriterMetaData(
41984cd51bbSwren romano     void *p, index_type dimRank, index_type nse,
42084cd51bbSwren romano     StridedMemRefType<index_type, 1> *dimSizesRef) {
421a3e48883Swren romano   assert(p);
42284cd51bbSwren romano   ASSERT_NO_STRIDE(dimSizesRef);
42384cd51bbSwren romano   assert(dimRank != 0);
42484cd51bbSwren romano   index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
42548962383SAart Bik   std::ostream &file = *static_cast<std::ostream *>(p);
426ac8b53fcSAdrian Kuegel   file << dimRank << " " << nse << '\n';
42748962383SAart Bik   for (index_type d = 0; d < dimRank - 1; d++)
42884cd51bbSwren romano     file << dimSizes[d] << " ";
429ac8b53fcSAdrian Kuegel   file << dimSizes[dimRank - 1] << '\n';
430b32831f4Swren romano }
431b32831f4Swren romano 
/// Generates `_mlir_ciface_outSparseTensorWriterNext<VNAME>` for each
/// supported value type. Each generated function appends one element line
/// to the stream behind `p`. The line holds the `dimRank` coordinates,
/// shifted from 0-based to 1-based, each followed by a space, then the value
/// and a newline.
#define IMPL_OUTNEXT(VNAME, V)                                                 \
  void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
      void *p, index_type dimRank,                                             \
      StridedMemRefType<index_type, 1> *dimCoordsRef,                          \
      StridedMemRefType<V, 0> *vref) {                                         \
    assert(p && vref);                                                         \
    ASSERT_NO_STRIDE(dimCoordsRef);                                            \
    const index_type *coords = MEMREF_GET_PAYLOAD(dimCoordsRef);               \
    std::ostream &out = *static_cast<std::ostream *>(p);                       \
    const index_type *const coordsEnd = coords + dimRank;                      \
    for (const index_type *c = coords; c != coordsEnd; ++c)                    \
      out << (*c + 1) << " ";                                                  \
    const V *valuePtr = MEMREF_GET_PAYLOAD(vref);                              \
    out << *valuePtr << '\n';                                                  \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
#undef IMPL_OUTNEXT
448b32831f4Swren romano 
44962066929Swren romano //===----------------------------------------------------------------------===//
45062066929Swren romano //
45162066929Swren romano // Public functions which accept only C-style data structures to interact
45262066929Swren romano // with sparse tensors (which are only visible as opaque pointers externally).
45362066929Swren romano //
45462066929Swren romano //===----------------------------------------------------------------------===//
45562066929Swren romano 
45686f91e45Swren romano index_type sparseLvlSize(void *tensor, index_type l) {
45786f91e45Swren romano   return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
45886f91e45Swren romano }
45986f91e45Swren romano 
sparseDimSize(void * tensor,index_type d)46086f91e45Swren romano index_type sparseDimSize(void *tensor, index_type d) {
46186f91e45Swren romano   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
46262066929Swren romano }
46362066929Swren romano 
endLexInsert(void * tensor)4642045cca0SAart Bik void endLexInsert(void *tensor) {
4652045cca0SAart Bik   return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
46662066929Swren romano }
46762066929Swren romano 
delSparseTensor(void * tensor)46862066929Swren romano void delSparseTensor(void *tensor) {
46962066929Swren romano   delete static_cast<SparseTensorStorageBase *>(tensor);
47062066929Swren romano }
47162066929Swren romano 
getTensorFilename(index_type id)47262066929Swren romano char *getTensorFilename(index_type id) {
473bf4480d9SMehdi Amini   constexpr size_t bufSize = 80;
474bf4480d9SMehdi Amini   char var[bufSize];
475bf4480d9SMehdi Amini   snprintf(var, bufSize, "TENSOR%" PRIu64, id);
47662066929Swren romano   char *env = getenv(var);
4771c2456d6SAart Bik   if (!env) {
4781c2456d6SAart Bik     fprintf(stderr, "Environment variable %s is not set\n", var);
4791c2456d6SAart Bik     exit(1);
4801c2456d6SAart Bik   }
48162066929Swren romano   return env;
48262066929Swren romano }
48362066929Swren romano 
getSparseTensorReaderNSE(void * p)48484cd51bbSwren romano index_type getSparseTensorReaderNSE(void *p) {
48584cd51bbSwren romano   return static_cast<SparseTensorReader *>(p)->getNSE();
486f2b73f51Sbixia1 }
487f2b73f51Sbixia1 
delSparseTensorReader(void * p)488f2b73f51Sbixia1 void delSparseTensorReader(void *p) {
489f2b73f51Sbixia1   delete static_cast<SparseTensorReader *>(p);
490f2b73f51Sbixia1 }
491f2b73f51Sbixia1 
/// Creates an opaque writer handle (a `std::ostream *`) and emits the
/// extended-FROSTT header comment.
///
/// @param filename  output path. An empty string selects standard output,
///                  which is not owned by the handle.
/// @return the stream as `void *`; release it with `delSparseTensorWriter`.
///
/// If the file cannot be opened, a diagnostic is printed to stderr and the
/// process exits with code 1. Without this check every later write would be
/// silently dropped, and the failure would only surface as a debug-only
/// assert when the writer is deleted.
void *createSparseTensorWriter(char *filename) {
  std::ostream *file = &std::cout;
  if (filename[0] != 0) {
    auto *fileStream = new std::ofstream(filename);
    if (!fileStream->is_open()) {
      fprintf(stderr, "Cannot open file %s for writing\n", filename);
      delete fileStream;
      exit(1);
    }
    file = fileStream;
  }
  *file << "# extended FROSTT format\n";
  return static_cast<void *>(file);
}
498f2b73f51Sbixia1 
/// Flushes and releases a writer handle from `createSparseTensorWriter`.
/// Standard output is flushed but never deleted, because the handle does
/// not own it.
void delSparseTensorWriter(void *p) {
  auto *out = static_cast<std::ostream *>(p);
  out->flush();
  assert(out->good());
  if (out == &std::cout)
    return;
  delete out;
}
506f2b73f51Sbixia1 
507af5c3079SStella Stamenova } // extern "C"
5086c22dad9Swren romano 
509a3e48883Swren romano #undef MEMREF_GET_PAYLOAD
510a3e48883Swren romano #undef ASSERT_USIZE_EQ
511a3e48883Swren romano #undef MEMREF_GET_USIZE
512a3e48883Swren romano #undef ASSERT_NO_STRIDE
513a3e48883Swren romano 
51462066929Swren romano #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
515