xref: /llvm-project/mlir/lib/ExecutionEngine/SparseTensor/File.cpp (revision 329f2f103af14b675daf9c3969c117dcfb785a8a)
1 //===- File.cpp - Parsing sparse tensors from files -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements parsing and printing of files in one of the
10 // following external formats:
11 //
12 // (1) Matrix Market Exchange (MME): *.mtx
13 //     https://math.nist.gov/MatrixMarket/formats.html
14 //
15 // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
16 //     http://frostt.io/tensors/file-formats.html
17 //
18 // This file is part of the lightweight runtime support library for sparse
19 // tensor manipulations.  The functionality of the support library is meant
20 // to simplify benchmarking, testing, and debugging MLIR code operating on
21 // sparse tensors.  However, the provided functionality is **not** part of
22 // core MLIR itself.
23 //
24 //===----------------------------------------------------------------------===//
25 
#include "mlir/ExecutionEngine/SparseTensor/File.h"

#include <cctype>
#include <cstring>
#include <iterator>
30 
31 using namespace mlir::sparse_tensor;
32 
33 /// Opens the file for reading.
34 void SparseTensorFile::openFile() {
35   if (file)
36     MLIR_SPARSETENSOR_FATAL("Already opened file %s\n", filename);
37   file = fopen(filename, "r");
38   if (!file)
39     MLIR_SPARSETENSOR_FATAL("Cannot find file %s\n", filename);
40 }
41 
42 /// Closes the file.
43 void SparseTensorFile::closeFile() {
44   if (file) {
45     fclose(file);
46     file = nullptr;
47   }
48 }
49 
// TODO(wrengr/bixia): figure out how to reorganize the element-parsing
// loop of `openSparseTensorCOO` into methods of this class, so we can
// avoid leaking access to the `line` pointer (both for general hygiene
// and because we can't mark it const due to the second argument of
// `strtoul`/`strtod` being `char * *restrict` rather than
// `char const* *restrict`).
56 //
57 /// Attempts to read a line from the file.
58 char *SparseTensorFile::readLine() {
59   if (fgets(line, kColWidth, file))
60     return line;
61   MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename);
62 }
63 
64 /// Reads and parses the file's header.
65 void SparseTensorFile::readHeader() {
66   assert(file && "Attempt to readHeader() before openFile()");
67   if (strstr(filename, ".mtx"))
68     readMMEHeader();
69   else if (strstr(filename, ".tns"))
70     readExtFROSTTHeader();
71   else
72     MLIR_SPARSETENSOR_FATAL("Unknown format %s\n", filename);
73   assert(isValid() && "Failed to read the header");
74 }
75 
76 /// Asserts the shape subsumes the actual dimension sizes.  Is only
77 /// valid after parsing the header.
78 void SparseTensorFile::assertMatchesShape(uint64_t rank,
79                                           const uint64_t *shape) const {
80   assert(rank == getRank() && "Rank mismatch");
81   for (uint64_t r = 0; r < rank; ++r)
82     assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
83            "Dimension size mismatch");
84 }
85 
/// Converts the NUL-terminated string `token` to lower case in place and
/// returns `token` for convenient chaining.
///
/// Each character is routed through `unsigned char` before `std::tolower`:
/// passing a negative `char` (any non-ASCII byte on platforms where `char` is
/// signed) directly to `tolower` is undefined behavior.
static inline char *toLower(char *token) {
  for (char *c = token; *c; ++c)
    *c = static_cast<char>(std::tolower(static_cast<unsigned char>(*c)));
  return token;
}
92 
93 /// Read the MME header of a general sparse matrix of type real.
94 void SparseTensorFile::readMMEHeader() {
95   char header[64];
96   char object[64];
97   char format[64];
98   char field[64];
99   char symmetry[64];
100   // Read header line.
101   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
102              symmetry) != 5)
103     MLIR_SPARSETENSOR_FATAL("Corrupt header in %s\n", filename);
104   // Process `field`, which specify pattern or the data type of the values.
105   if (strcmp(toLower(field), "pattern") == 0)
106     valueKind_ = ValueKind::kPattern;
107   else if (strcmp(toLower(field), "real") == 0)
108     valueKind_ = ValueKind::kReal;
109   else if (strcmp(toLower(field), "integer") == 0)
110     valueKind_ = ValueKind::kInteger;
111   else if (strcmp(toLower(field), "complex") == 0)
112     valueKind_ = ValueKind::kComplex;
113   else
114     MLIR_SPARSETENSOR_FATAL("Unexpected header field value in %s\n", filename);
115 
116   // Set properties.
117   isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
118   // Make sure this is a general sparse matrix.
119   if (strcmp(toLower(header), "%%matrixmarket") ||
120       strcmp(toLower(object), "matrix") ||
121       strcmp(toLower(format), "coordinate") ||
122       (strcmp(toLower(symmetry), "general") && !isSymmetric_))
123     MLIR_SPARSETENSOR_FATAL("Cannot find a general sparse matrix in %s\n",
124                             filename);
125   // Skip comments.
126   while (true) {
127     readLine();
128     if (line[0] != '%')
129       break;
130   }
131   // Next line contains M N NNZ.
132   idata[0] = 2; // rank
133   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
134              idata + 1) != 3)
135     MLIR_SPARSETENSOR_FATAL("Cannot find size in %s\n", filename);
136 }
137 
138 /// Read the "extended" FROSTT header. Although not part of the documented
139 /// format, we assume that the file starts with optional comments followed
140 /// by two lines that define the rank, the number of nonzeros, and the
141 /// dimensions sizes (one per rank) of the sparse tensor.
142 void SparseTensorFile::readExtFROSTTHeader() {
143   // Skip comments.
144   while (true) {
145     readLine();
146     if (line[0] != '#')
147       break;
148   }
149   // Next line contains RANK and NNZ.
150   if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
151     MLIR_SPARSETENSOR_FATAL("Cannot find metadata in %s\n", filename);
152   // Followed by a line with the dimension sizes (one per rank).
153   for (uint64_t r = 0; r < idata[0]; ++r)
154     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
155       MLIR_SPARSETENSOR_FATAL("Cannot find dimension size %s\n", filename);
156   readLine(); // end of line
157   // The FROSTT format does not define the data type of the nonzero elements.
158   valueKind_ = ValueKind::kUndefined;
159 }
160