xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/File.cpp (revision 427f120f60becd23d6e037ccd14104adde8a3af9)
1 //===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements reading and writing sparse tensor files.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/ExecutionEngine/SparseTensor/File.h"

#include <cctype>
#include <cstring>
#include <iterator>
17 
18 using namespace mlir::sparse_tensor;
19 
20 /// Opens the file for reading.
21 void SparseTensorReader::openFile() {
22   if (file)
23     MLIR_SPARSETENSOR_FATAL("Already opened file %s\n", filename);
24   file = fopen(filename, "r");
25   if (!file)
26     MLIR_SPARSETENSOR_FATAL("Cannot find file %s\n", filename);
27 }
28 
29 /// Closes the file.
30 void SparseTensorReader::closeFile() {
31   if (file) {
32     fclose(file);
33     file = nullptr;
34   }
35 }
36 
37 /// Attempts to read a line from the file.
38 void SparseTensorReader::readLine() {
39   if (!fgets(line, kColWidth, file))
40     MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename);
41 }
42 
43 /// Reads and parses the file's header.
44 void SparseTensorReader::readHeader() {
45   assert(file && "Attempt to readHeader() before openFile()");
46   if (strstr(filename, ".mtx"))
47     readMMEHeader();
48   else if (strstr(filename, ".tns"))
49     readExtFROSTTHeader();
50   else
51     MLIR_SPARSETENSOR_FATAL("Unknown format %s\n", filename);
52   assert(isValid() && "Failed to read the header");
53 }
54 
55 /// Asserts the shape subsumes the actual dimension sizes.  Is only
56 /// valid after parsing the header.
57 void SparseTensorReader::assertMatchesShape(uint64_t rank,
58                                             const uint64_t *shape) const {
59   assert(rank == getRank() && "Rank mismatch");
60   for (uint64_t r = 0; r < rank; ++r)
61     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
62            "Dimension size mismatch");
63 }
64 
65 bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
66   switch (valueKind_) {
67   case ValueKind::kInvalid:
68     assert(false && "Must readHeader() before calling canReadAs()");
69     return false; // In case assertions are disabled.
70   case ValueKind::kPattern:
71     return true;
72   case ValueKind::kInteger:
73     // When the file is specified to store integer values, we still
74     // allow implicitly converting those to floating primary-types.
75     return isRealPrimaryType(valTy);
76   case ValueKind::kReal:
77     // When the file is specified to store real/floating values, then
78     // we disallow implicit conversion to integer primary-types.
79     return isFloatingPrimaryType(valTy);
80   case ValueKind::kComplex:
81     // When the file is specified to store complex values, then we
82     // require a complex primary-type.
83     return isComplexPrimaryType(valTy);
84   case ValueKind::kUndefined:
85     // The "extended" FROSTT format doesn't specify a ValueKind.
86     // So we allow implicitly converting the stored values to both
87     // integer and floating primary-types.
88     return isRealPrimaryType(valTy);
89   }
90   MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n",
91                           static_cast<uint8_t>(valueKind_));
92 }
93 
/// Converts a C-style ('\0'-terminated) string to lower case in place.
/// Each character is passed to tolower() as an `unsigned char`: when `char`
/// is signed, bytes >= 0x80 (e.g. non-ASCII bytes in an untrusted file
/// header) would otherwise arrive as negative values, which is undefined
/// behavior for tolower().
static inline void toLower(char *token) {
  for (char *c = token; *c; ++c)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
}
99 
/// Returns true iff the two '\0'-terminated strings have identical contents.
static inline bool streq(const char *lhs, const char *rhs) {
  const int order = strcmp(lhs, rhs);
  return !order;
}
104 
/// Returns true iff the two '\0'-terminated strings differ in content.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
109 
110 /// Read the MME header of a general sparse matrix of type real.
111 void SparseTensorReader::readMMEHeader() {
112   char header[64];
113   char object[64];
114   char format[64];
115   char field[64];
116   char symmetry[64];
117   // Read header line.
118   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
119              symmetry) != 5)
120     MLIR_SPARSETENSOR_FATAL("Corrupt header in %s\n", filename);
121   // Convert all to lowercase up front (to avoid accidental redundancy).
122   toLower(header);
123   toLower(object);
124   toLower(format);
125   toLower(field);
126   toLower(symmetry);
127   // Process `field`, which specify pattern or the data type of the values.
128   if (streq(field, "pattern"))
129     valueKind_ = ValueKind::kPattern;
130   else if (streq(field, "real"))
131     valueKind_ = ValueKind::kReal;
132   else if (streq(field, "integer"))
133     valueKind_ = ValueKind::kInteger;
134   else if (streq(field, "complex"))
135     valueKind_ = ValueKind::kComplex;
136   else
137     MLIR_SPARSETENSOR_FATAL("Unexpected header field value in %s\n", filename);
138   // Set properties.
139   isSymmetric_ = streq(symmetry, "symmetric");
140   // Make sure this is a general sparse matrix.
141   if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
142       strne(format, "coordinate") ||
143       (strne(symmetry, "general") && !isSymmetric_))
144     MLIR_SPARSETENSOR_FATAL("Cannot find a general sparse matrix in %s\n",
145                             filename);
146   // Skip comments.
147   while (true) {
148     readLine();
149     if (line[0] != '%')
150       break;
151   }
152   // Next line contains M N NNZ.
153   idata[0] = 2; // rank
154   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
155              idata + 1) != 3)
156     MLIR_SPARSETENSOR_FATAL("Cannot find size in %s\n", filename);
157 }
158 
159 /// Read the "extended" FROSTT header. Although not part of the documented
160 /// format, we assume that the file starts with optional comments followed
161 /// by two lines that define the rank, the number of nonzeros, and the
162 /// dimensions sizes (one per rank) of the sparse tensor.
163 void SparseTensorReader::readExtFROSTTHeader() {
164   // Skip comments.
165   while (true) {
166     readLine();
167     if (line[0] != '#')
168       break;
169   }
170   // Next line contains RANK and NNZ.
171   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
172     MLIR_SPARSETENSOR_FATAL("Cannot find metadata in %s\n", filename);
173   // Followed by a line with the dimension sizes (one per rank).
174   for (uint64_t r = 0; r < idata[0]; ++r)
175     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
176       MLIR_SPARSETENSOR_FATAL("Cannot find dimension size %s\n", filename);
177   readLine(); // end of line
178   // The FROSTT format does not define the data type of the nonzero elements.
179   valueKind_ = ValueKind::kUndefined;
180 }
181