xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/File.cpp (revision 2af2e4dbb790daafd3cbbf6189a7a27145cf4c12)
1 //===- File.cpp - Parsing sparse tensors from files -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements parsing and printing of files in one of the
10 // following external formats:
11 //
12 // (1) Matrix Market Exchange (MME): *.mtx
13 //     https://math.nist.gov/MatrixMarket/formats.html
14 //
15 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
16 //     http://frostt.io/tensors/file-formats.html
17 //
18 // This file is part of the lightweight runtime support library for sparse
19 // tensor manipulations.  The functionality of the support library is meant
20 // to simplify benchmarking, testing, and debugging MLIR code operating on
21 // sparse tensors.  However, the provided functionality is **not** part of
22 // core MLIR itself.
23 //
24 //===----------------------------------------------------------------------===//
25 
26 #include "mlir/ExecutionEngine/SparseTensor/File.h"
27 
28 #include <cctype>
29 #include <cstring>
30 
31 using namespace mlir::sparse_tensor;
32 
33 /// Opens the file for reading.
34 void SparseTensorReader::openFile() {
35   if (file)
36     MLIR_SPARSETENSOR_FATAL("Already opened file %s\n", filename);
37   file = fopen(filename, "r");
38   if (!file)
39     MLIR_SPARSETENSOR_FATAL("Cannot find file %s\n", filename);
40 }
41 
42 /// Closes the file.
43 void SparseTensorReader::closeFile() {
44   if (file) {
45     fclose(file);
46     file = nullptr;
47   }
48 }
49 
50 /// Attempts to read a line from the file.
51 void SparseTensorReader::readLine() {
52   if (!fgets(line, kColWidth, file))
53     MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename);
54 }
55 
56 char *SparseTensorReader::readCOOIndices(uint64_t *indices) {
57   readLine();
58   // Local variable for tracking the parser's position in the `line` buffer.
59   char *linePtr = line;
60   for (uint64_t rank = getRank(), r = 0; r < rank; ++r) {
61     // Parse the 1-based index.
62     uint64_t idx = strtoul(linePtr, &linePtr, 10);
63     // Store the 0-based index.
64     indices[r] = idx - 1;
65   }
66   return linePtr;
67 }
68 
69 /// Reads and parses the file's header.
70 void SparseTensorReader::readHeader() {
71   assert(file && "Attempt to readHeader() before openFile()");
72   if (strstr(filename, ".mtx"))
73     readMMEHeader();
74   else if (strstr(filename, ".tns"))
75     readExtFROSTTHeader();
76   else
77     MLIR_SPARSETENSOR_FATAL("Unknown format %s\n", filename);
78   assert(isValid() && "Failed to read the header");
79 }
80 
81 /// Asserts the shape subsumes the actual dimension sizes.  Is only
82 /// valid after parsing the header.
83 void SparseTensorReader::assertMatchesShape(uint64_t rank,
84                                             const uint64_t *shape) const {
85   assert(rank == getRank() && "Rank mismatch");
86   for (uint64_t r = 0; r < rank; ++r)
87     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
88            "Dimension size mismatch");
89 }
90 
91 bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
92   switch (valueKind_) {
93   case ValueKind::kInvalid:
94     assert(false && "Must readHeader() before calling canReadAs()");
95     return false; // In case assertions are disabled.
96   case ValueKind::kPattern:
97     return true;
98   case ValueKind::kInteger:
99     // When the file is specified to store integer values, we still
100     // allow implicitly converting those to floating primary-types.
101     return isRealPrimaryType(valTy);
102   case ValueKind::kReal:
103     // When the file is specified to store real/floating values, then
104     // we disallow implicit conversion to integer primary-types.
105     return isFloatingPrimaryType(valTy);
106   case ValueKind::kComplex:
107     // When the file is specified to store complex values, then we
108     // require a complex primary-type.
109     return isComplexPrimaryType(valTy);
110   case ValueKind::kUndefined:
111     // The "extended" FROSTT format doesn't specify a ValueKind.
112     // So we allow implicitly converting the stored values to both
113     // integer and floating primary-types.
114     return isRealPrimaryType(valTy);
115   }
116   MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n",
117                           static_cast<uint8_t>(valueKind_));
118 }
119 
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.
static inline void toLower(char *token) {
  for (char *c = token; *c != '\0'; ++c) {
    // `tolower` requires an argument representable as `unsigned char` (or
    // EOF); passing a negative plain `char` is undefined behavior, so go
    // through `unsigned char` first.
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  }
}
125 
/// Returns true iff the two C-style strings compare equal under `strcmp`.
static inline bool streq(const char *lhs, const char *rhs) {
  const int order = strcmp(lhs, rhs);
  return !order;
}
130 
/// Returns true iff the two C-style strings differ under `strcmp`.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
135 
136 /// Read the MME header of a general sparse matrix of type real.
137 void SparseTensorReader::readMMEHeader() {
138   char header[64];
139   char object[64];
140   char format[64];
141   char field[64];
142   char symmetry[64];
143   // Read header line.
144   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
145              symmetry) != 5)
146     MLIR_SPARSETENSOR_FATAL("Corrupt header in %s\n", filename);
147   // Convert all to lowercase up front (to avoid accidental redundancy).
148   toLower(header);
149   toLower(object);
150   toLower(format);
151   toLower(field);
152   toLower(symmetry);
153   // Process `field`, which specify pattern or the data type of the values.
154   if (streq(field, "pattern"))
155     valueKind_ = ValueKind::kPattern;
156   else if (streq(field, "real"))
157     valueKind_ = ValueKind::kReal;
158   else if (streq(field, "integer"))
159     valueKind_ = ValueKind::kInteger;
160   else if (streq(field, "complex"))
161     valueKind_ = ValueKind::kComplex;
162   else
163     MLIR_SPARSETENSOR_FATAL("Unexpected header field value in %s\n", filename);
164   // Set properties.
165   isSymmetric_ = streq(symmetry, "symmetric");
166   // Make sure this is a general sparse matrix.
167   if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
168       strne(format, "coordinate") ||
169       (strne(symmetry, "general") && !isSymmetric_))
170     MLIR_SPARSETENSOR_FATAL("Cannot find a general sparse matrix in %s\n",
171                             filename);
172   // Skip comments.
173   while (true) {
174     readLine();
175     if (line[0] != '%')
176       break;
177   }
178   // Next line contains M N NNZ.
179   idata[0] = 2; // rank
180   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
181              idata + 1) != 3)
182     MLIR_SPARSETENSOR_FATAL("Cannot find size in %s\n", filename);
183 }
184 
185 /// Read the "extended" FROSTT header. Although not part of the documented
186 /// format, we assume that the file starts with optional comments followed
187 /// by two lines that define the rank, the number of nonzeros, and the
188 /// dimensions sizes (one per rank) of the sparse tensor.
189 void SparseTensorReader::readExtFROSTTHeader() {
190   // Skip comments.
191   while (true) {
192     readLine();
193     if (line[0] != '#')
194       break;
195   }
196   // Next line contains RANK and NNZ.
197   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
198     MLIR_SPARSETENSOR_FATAL("Cannot find metadata in %s\n", filename);
199   // Followed by a line with the dimension sizes (one per rank).
200   for (uint64_t r = 0; r < idata[0]; ++r)
201     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
202       MLIR_SPARSETENSOR_FATAL("Cannot find dimension size %s\n", filename);
203   readLine(); // end of line
204   // The FROSTT format does not define the data type of the nonzero elements.
205   valueKind_ = ValueKind::kUndefined;
206 }
207