xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/File.cpp (revision 1c2456d6593cea317a00627889b5c35766a732e0)
1 //===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements reading and writing sparse tensor files.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/ExecutionEngine/SparseTensor/File.h"
14 
#include <cctype>
#include <cstring>
#include <iterator>
17 
18 using namespace mlir::sparse_tensor;
19 
20 /// Opens the file for reading.
openFile()21 void SparseTensorReader::openFile() {
22   if (file) {
23     fprintf(stderr, "Already opened file %s\n", filename);
24     exit(1);
25   }
26   file = fopen(filename, "r");
27   if (!file) {
28     fprintf(stderr, "Cannot find file %s\n", filename);
29     exit(1);
30   }
31 }
32 
33 /// Closes the file.
closeFile()34 void SparseTensorReader::closeFile() {
35   if (file) {
36     fclose(file);
37     file = nullptr;
38   }
39 }
40 
41 /// Attempts to read a line from the file.
readLine()42 void SparseTensorReader::readLine() {
43   if (!fgets(line, kColWidth, file)) {
44     fprintf(stderr, "Cannot read next line of %s\n", filename);
45     exit(1);
46   }
47 }
48 
49 /// Reads and parses the file's header.
readHeader()50 void SparseTensorReader::readHeader() {
51   assert(file && "Attempt to readHeader() before openFile()");
52   if (strstr(filename, ".mtx")) {
53     readMMEHeader();
54   } else if (strstr(filename, ".tns")) {
55     readExtFROSTTHeader();
56   } else {
57     fprintf(stderr, "Unknown format %s\n", filename);
58     exit(1);
59   }
60   assert(isValid() && "Failed to read the header");
61 }
62 
63 /// Asserts the shape subsumes the actual dimension sizes.  Is only
64 /// valid after parsing the header.
assertMatchesShape(uint64_t rank,const uint64_t * shape) const65 void SparseTensorReader::assertMatchesShape(uint64_t rank,
66                                             const uint64_t *shape) const {
67   assert(rank == getRank() && "Rank mismatch");
68   for (uint64_t r = 0; r < rank; r++)
69     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
70            "Dimension size mismatch");
71 }
72 
canReadAs(PrimaryType valTy) const73 bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
74   switch (valueKind_) {
75   case ValueKind::kInvalid:
76     assert(false && "Must readHeader() before calling canReadAs()");
77     return false; // In case assertions are disabled.
78   case ValueKind::kPattern:
79     return true;
80   case ValueKind::kInteger:
81     // When the file is specified to store integer values, we still
82     // allow implicitly converting those to floating primary-types.
83     return isRealPrimaryType(valTy);
84   case ValueKind::kReal:
85     // When the file is specified to store real/floating values, then
86     // we disallow implicit conversion to integer primary-types.
87     return isFloatingPrimaryType(valTy);
88   case ValueKind::kComplex:
89     // When the file is specified to store complex values, then we
90     // require a complex primary-type.
91     return isComplexPrimaryType(valTy);
92   case ValueKind::kUndefined:
93     // The "extended" FROSTT format doesn't specify a ValueKind.
94     // So we allow implicitly converting the stored values to both
95     // integer and floating primary-types.
96     return isRealPrimaryType(valTy);
97   }
98   fprintf(stderr, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_));
99   return false;
100 }
101 
/// Converts a '\0'-terminated C string to lower case, in place.
///
/// Each character is converted to `unsigned char` before calling `tolower`.
/// Passing a negative `char` (e.g. a byte >= 0x80 on targets where `char` is
/// signed) to `tolower` is undefined behavior, and header tokens come from an
/// untrusted file.
static inline void toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
}
107 
/// Returns true iff the two '\0'-terminated strings have identical contents.
static inline bool streq(const char *lhs, const char *rhs) {
  const int order = strcmp(lhs, rhs);
  return !order;
}
112 
/// Returns true iff the two '\0'-terminated strings differ in content.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
117 
118 /// Read the MME header of a general sparse matrix of type real.
readMMEHeader()119 void SparseTensorReader::readMMEHeader() {
120   char header[64];
121   char object[64];
122   char format[64];
123   char field[64];
124   char symmetry[64];
125   // Read header line.
126   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
127              symmetry) != 5) {
128     fprintf(stderr, "Corrupt header in %s\n", filename);
129     exit(1);
130   }
131   // Convert all to lowercase up front (to avoid accidental redundancy).
132   toLower(header);
133   toLower(object);
134   toLower(format);
135   toLower(field);
136   toLower(symmetry);
137   // Process `field`, which specify pattern or the data type of the values.
138   if (streq(field, "pattern")) {
139     valueKind_ = ValueKind::kPattern;
140   } else if (streq(field, "real")) {
141     valueKind_ = ValueKind::kReal;
142   } else if (streq(field, "integer")) {
143     valueKind_ = ValueKind::kInteger;
144   } else if (streq(field, "complex")) {
145     valueKind_ = ValueKind::kComplex;
146   } else {
147     fprintf(stderr, "Unexpected header field value in %s\n", filename);
148     exit(1);
149   }
150   // Set properties.
151   isSymmetric_ = streq(symmetry, "symmetric");
152   // Make sure this is a general sparse matrix.
153   if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
154       strne(format, "coordinate") ||
155       (strne(symmetry, "general") && !isSymmetric_)) {
156     fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
157     exit(1);
158   }
159   // Skip comments.
160   while (true) {
161     readLine();
162     if (line[0] != '%')
163       break;
164   }
165   // Next line contains M N NNZ.
166   idata[0] = 2; // rank
167   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
168              idata + 1) != 3) {
169     fprintf(stderr, "Cannot find size in %s\n", filename);
170     exit(1);
171   }
172 }
173 
174 /// Read the "extended" FROSTT header. Although not part of the documented
175 /// format, we assume that the file starts with optional comments followed
176 /// by two lines that define the rank, the number of nonzeros, and the
177 /// dimensions sizes (one per rank) of the sparse tensor.
readExtFROSTTHeader()178 void SparseTensorReader::readExtFROSTTHeader() {
179   // Skip comments.
180   while (true) {
181     readLine();
182     if (line[0] != '#')
183       break;
184   }
185   // Next line contains RANK and NNZ.
186   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
187     fprintf(stderr, "Cannot find metadata in %s\n", filename);
188     exit(1);
189   }
190   // Followed by a line with the dimension sizes (one per rank).
191   for (uint64_t r = 0; r < idata[0]; r++) {
192     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
193       fprintf(stderr, "Cannot find dimension size %s\n", filename);
194       exit(1);
195     }
196   }
197   readLine(); // end of line
198   // The FROSTT format does not define the data type of the nonzero elements.
199   valueKind_ = ValueKind::kUndefined;
200 }
201