//===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements reading and writing sparse tensor files.
//
//===----------------------------------------------------------------------===//
120fca5c5fSwren romano
130fca5c5fSwren romano #include "mlir/ExecutionEngine/SparseTensor/File.h"
140fca5c5fSwren romano
150fca5c5fSwren romano #include <cctype>
160fca5c5fSwren romano #include <cstring>
170fca5c5fSwren romano
180fca5c5fSwren romano using namespace mlir::sparse_tensor;
190fca5c5fSwren romano
200fca5c5fSwren romano /// Opens the file for reading.
openFile()21461c461aSbixia1 void SparseTensorReader::openFile() {
22*1c2456d6SAart Bik if (file) {
23*1c2456d6SAart Bik fprintf(stderr, "Already opened file %s\n", filename);
24*1c2456d6SAart Bik exit(1);
25*1c2456d6SAart Bik }
260fca5c5fSwren romano file = fopen(filename, "r");
27*1c2456d6SAart Bik if (!file) {
28*1c2456d6SAart Bik fprintf(stderr, "Cannot find file %s\n", filename);
29*1c2456d6SAart Bik exit(1);
30*1c2456d6SAart Bik }
310fca5c5fSwren romano }
320fca5c5fSwren romano
330fca5c5fSwren romano /// Closes the file.
closeFile()34461c461aSbixia1 void SparseTensorReader::closeFile() {
350fca5c5fSwren romano if (file) {
360fca5c5fSwren romano fclose(file);
370fca5c5fSwren romano file = nullptr;
380fca5c5fSwren romano }
390fca5c5fSwren romano }
400fca5c5fSwren romano
410fca5c5fSwren romano /// Attempts to read a line from the file.
readLine()42c518745bSwren romano void SparseTensorReader::readLine() {
43*1c2456d6SAart Bik if (!fgets(line, kColWidth, file)) {
44*1c2456d6SAart Bik fprintf(stderr, "Cannot read next line of %s\n", filename);
45*1c2456d6SAart Bik exit(1);
46*1c2456d6SAart Bik }
470fca5c5fSwren romano }
480fca5c5fSwren romano
490fca5c5fSwren romano /// Reads and parses the file's header.
readHeader()50461c461aSbixia1 void SparseTensorReader::readHeader() {
510fca5c5fSwren romano assert(file && "Attempt to readHeader() before openFile()");
52*1c2456d6SAart Bik if (strstr(filename, ".mtx")) {
530fca5c5fSwren romano readMMEHeader();
54*1c2456d6SAart Bik } else if (strstr(filename, ".tns")) {
550fca5c5fSwren romano readExtFROSTTHeader();
56*1c2456d6SAart Bik } else {
57*1c2456d6SAart Bik fprintf(stderr, "Unknown format %s\n", filename);
58*1c2456d6SAart Bik exit(1);
59*1c2456d6SAart Bik }
600fca5c5fSwren romano assert(isValid() && "Failed to read the header");
610fca5c5fSwren romano }
620fca5c5fSwren romano
630fca5c5fSwren romano /// Asserts the shape subsumes the actual dimension sizes. Is only
640fca5c5fSwren romano /// valid after parsing the header.
assertMatchesShape(uint64_t rank,const uint64_t * shape) const65461c461aSbixia1 void SparseTensorReader::assertMatchesShape(uint64_t rank,
660fca5c5fSwren romano const uint64_t *shape) const {
670fca5c5fSwren romano assert(rank == getRank() && "Rank mismatch");
68*1c2456d6SAart Bik for (uint64_t r = 0; r < rank; r++)
690fca5c5fSwren romano assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
700fca5c5fSwren romano "Dimension size mismatch");
710fca5c5fSwren romano }
720fca5c5fSwren romano
canReadAs(PrimaryType valTy) const73461c461aSbixia1 bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
74c8177f84Swren romano switch (valueKind_) {
75c8177f84Swren romano case ValueKind::kInvalid:
76c8177f84Swren romano assert(false && "Must readHeader() before calling canReadAs()");
77c8177f84Swren romano return false; // In case assertions are disabled.
78c8177f84Swren romano case ValueKind::kPattern:
79c8177f84Swren romano return true;
80c8177f84Swren romano case ValueKind::kInteger:
81c8177f84Swren romano // When the file is specified to store integer values, we still
82c8177f84Swren romano // allow implicitly converting those to floating primary-types.
83c8177f84Swren romano return isRealPrimaryType(valTy);
84c8177f84Swren romano case ValueKind::kReal:
85c8177f84Swren romano // When the file is specified to store real/floating values, then
86c8177f84Swren romano // we disallow implicit conversion to integer primary-types.
87c8177f84Swren romano return isFloatingPrimaryType(valTy);
88c8177f84Swren romano case ValueKind::kComplex:
89c8177f84Swren romano // When the file is specified to store complex values, then we
90c8177f84Swren romano // require a complex primary-type.
91c8177f84Swren romano return isComplexPrimaryType(valTy);
92c8177f84Swren romano case ValueKind::kUndefined:
93c8177f84Swren romano // The "extended" FROSTT format doesn't specify a ValueKind.
94c8177f84Swren romano // So we allow implicitly converting the stored values to both
95c8177f84Swren romano // integer and floating primary-types.
96c8177f84Swren romano return isRealPrimaryType(valTy);
97c8177f84Swren romano }
98*1c2456d6SAart Bik fprintf(stderr, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_));
99*1c2456d6SAart Bik return false;
100c8177f84Swren romano }
101c8177f84Swren romano
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.
static inline void toLower(char *token) {
  // `tolower` requires an argument that is representable as `unsigned char`
  // (or equal to EOF); passing a negative plain `char` is undefined
  // behavior, so each character is converted to `unsigned char` first.
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
}
1070fca5c5fSwren romano
/// Returns true iff the two '\0'-terminated strings compare equal.
static inline bool streq(const char *lhs, const char *rhs) {
  return !strcmp(lhs, rhs);
}
1124792b8aeSwren romano
/// Returns true iff the two '\0'-terminated strings differ.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
1174792b8aeSwren romano
1180fca5c5fSwren romano /// Read the MME header of a general sparse matrix of type real.
readMMEHeader()119461c461aSbixia1 void SparseTensorReader::readMMEHeader() {
1200fca5c5fSwren romano char header[64];
1210fca5c5fSwren romano char object[64];
1220fca5c5fSwren romano char format[64];
1230fca5c5fSwren romano char field[64];
1240fca5c5fSwren romano char symmetry[64];
1250fca5c5fSwren romano // Read header line.
1260fca5c5fSwren romano if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
127*1c2456d6SAart Bik symmetry) != 5) {
128*1c2456d6SAart Bik fprintf(stderr, "Corrupt header in %s\n", filename);
129*1c2456d6SAart Bik exit(1);
130*1c2456d6SAart Bik }
1314792b8aeSwren romano // Convert all to lowercase up front (to avoid accidental redundancy).
1324792b8aeSwren romano toLower(header);
1334792b8aeSwren romano toLower(object);
1344792b8aeSwren romano toLower(format);
1354792b8aeSwren romano toLower(field);
1364792b8aeSwren romano toLower(symmetry);
1370fca5c5fSwren romano // Process `field`, which specify pattern or the data type of the values.
138*1c2456d6SAart Bik if (streq(field, "pattern")) {
1390fca5c5fSwren romano valueKind_ = ValueKind::kPattern;
140*1c2456d6SAart Bik } else if (streq(field, "real")) {
1410fca5c5fSwren romano valueKind_ = ValueKind::kReal;
142*1c2456d6SAart Bik } else if (streq(field, "integer")) {
1430fca5c5fSwren romano valueKind_ = ValueKind::kInteger;
144*1c2456d6SAart Bik } else if (streq(field, "complex")) {
1450fca5c5fSwren romano valueKind_ = ValueKind::kComplex;
146*1c2456d6SAart Bik } else {
147*1c2456d6SAart Bik fprintf(stderr, "Unexpected header field value in %s\n", filename);
148*1c2456d6SAart Bik exit(1);
149*1c2456d6SAart Bik }
1500fca5c5fSwren romano // Set properties.
1514792b8aeSwren romano isSymmetric_ = streq(symmetry, "symmetric");
1520fca5c5fSwren romano // Make sure this is a general sparse matrix.
1534792b8aeSwren romano if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
1544792b8aeSwren romano strne(format, "coordinate") ||
155*1c2456d6SAart Bik (strne(symmetry, "general") && !isSymmetric_)) {
156*1c2456d6SAart Bik fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
157*1c2456d6SAart Bik exit(1);
158*1c2456d6SAart Bik }
1590fca5c5fSwren romano // Skip comments.
1600fca5c5fSwren romano while (true) {
1610fca5c5fSwren romano readLine();
1620fca5c5fSwren romano if (line[0] != '%')
1630fca5c5fSwren romano break;
1640fca5c5fSwren romano }
1650fca5c5fSwren romano // Next line contains M N NNZ.
1660fca5c5fSwren romano idata[0] = 2; // rank
1670fca5c5fSwren romano if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
168*1c2456d6SAart Bik idata + 1) != 3) {
169*1c2456d6SAart Bik fprintf(stderr, "Cannot find size in %s\n", filename);
170*1c2456d6SAart Bik exit(1);
171*1c2456d6SAart Bik }
1720fca5c5fSwren romano }
1730fca5c5fSwren romano
1740fca5c5fSwren romano /// Read the "extended" FROSTT header. Although not part of the documented
1750fca5c5fSwren romano /// format, we assume that the file starts with optional comments followed
1760fca5c5fSwren romano /// by two lines that define the rank, the number of nonzeros, and the
1770fca5c5fSwren romano /// dimensions sizes (one per rank) of the sparse tensor.
readExtFROSTTHeader()178461c461aSbixia1 void SparseTensorReader::readExtFROSTTHeader() {
1790fca5c5fSwren romano // Skip comments.
1800fca5c5fSwren romano while (true) {
1810fca5c5fSwren romano readLine();
1820fca5c5fSwren romano if (line[0] != '#')
1830fca5c5fSwren romano break;
1840fca5c5fSwren romano }
1850fca5c5fSwren romano // Next line contains RANK and NNZ.
186*1c2456d6SAart Bik if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
187*1c2456d6SAart Bik fprintf(stderr, "Cannot find metadata in %s\n", filename);
188*1c2456d6SAart Bik exit(1);
189*1c2456d6SAart Bik }
1900fca5c5fSwren romano // Followed by a line with the dimension sizes (one per rank).
191*1c2456d6SAart Bik for (uint64_t r = 0; r < idata[0]; r++) {
192*1c2456d6SAart Bik if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
193*1c2456d6SAart Bik fprintf(stderr, "Cannot find dimension size %s\n", filename);
194*1c2456d6SAart Bik exit(1);
195*1c2456d6SAart Bik }
196*1c2456d6SAart Bik }
1970fca5c5fSwren romano readLine(); // end of line
1980fca5c5fSwren romano // The FROSTT format does not define the data type of the nonzero elements.
1990fca5c5fSwren romano valueKind_ = ValueKind::kUndefined;
2000fca5c5fSwren romano }
201