1 //===- File.cpp - Reading/writing sparse tensors from/to files ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements reading and writing sparse tensor files. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/ExecutionEngine/SparseTensor/File.h" 14 15 #include <cctype> 16 #include <cstring> 17 18 using namespace mlir::sparse_tensor; 19 20 /// Opens the file for reading. 21 void SparseTensorReader::openFile() { 22 if (file) 23 MLIR_SPARSETENSOR_FATAL("Already opened file %s\n", filename); 24 file = fopen(filename, "r"); 25 if (!file) 26 MLIR_SPARSETENSOR_FATAL("Cannot find file %s\n", filename); 27 } 28 29 /// Closes the file. 30 void SparseTensorReader::closeFile() { 31 if (file) { 32 fclose(file); 33 file = nullptr; 34 } 35 } 36 37 /// Attempts to read a line from the file. 38 void SparseTensorReader::readLine() { 39 if (!fgets(line, kColWidth, file)) 40 MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename); 41 } 42 43 /// Reads and parses the file's header. 44 void SparseTensorReader::readHeader() { 45 assert(file && "Attempt to readHeader() before openFile()"); 46 if (strstr(filename, ".mtx")) 47 readMMEHeader(); 48 else if (strstr(filename, ".tns")) 49 readExtFROSTTHeader(); 50 else 51 MLIR_SPARSETENSOR_FATAL("Unknown format %s\n", filename); 52 assert(isValid() && "Failed to read the header"); 53 } 54 55 /// Asserts the shape subsumes the actual dimension sizes. Is only 56 /// valid after parsing the header. 57 void SparseTensorReader::assertMatchesShape(uint64_t rank, 58 const uint64_t *shape) const { 59 assert(rank == getRank() && "Rank mismatch"); 60 for (uint64_t r = 0; r < rank; ++r) 61 assert((shape[r] == 0 || shape[r] == idata[2 + r]) && 62 "Dimension size mismatch"); 63 } 64 65 bool SparseTensorReader::canReadAs(PrimaryType valTy) const { 66 switch (valueKind_) { 67 case ValueKind::kInvalid: 68 assert(false && "Must readHeader() before calling canReadAs()"); 69 return false; // In case assertions are disabled. 70 case ValueKind::kPattern: 71 return true; 72 case ValueKind::kInteger: 73 // When the file is specified to store integer values, we still 74 // allow implicitly converting those to floating primary-types. 75 return isRealPrimaryType(valTy); 76 case ValueKind::kReal: 77 // When the file is specified to store real/floating values, then 78 // we disallow implicit conversion to integer primary-types. 79 return isFloatingPrimaryType(valTy); 80 case ValueKind::kComplex: 81 // When the file is specified to store complex values, then we 82 // require a complex primary-type. 83 return isComplexPrimaryType(valTy); 84 case ValueKind::kUndefined: 85 // The "extended" FROSTT format doesn't specify a ValueKind. 86 // So we allow implicitly converting the stored values to both 87 // integer and floating primary-types. 88 return isRealPrimaryType(valTy); 89 } 90 MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n", 91 static_cast<uint8_t>(valueKind_)); 92 } 93 94 /// Helper to convert C-style strings (i.e., '\0' terminated) to lower case. 95 static inline void toLower(char *token) { 96 for (char *c = token; *c; ++c) 97 *c = tolower(*c); 98 } 99 100 /// Idiomatic name for checking string equality. 101 static inline bool streq(const char *lhs, const char *rhs) { 102 return strcmp(lhs, rhs) == 0; 103 } 104 105 /// Idiomatic name for checking string inequality. 106 static inline bool strne(const char *lhs, const char *rhs) { 107 return strcmp(lhs, rhs); // aka `!= 0` 108 } 109 110 /// Read the MME header of a general sparse matrix of type real. 111 void SparseTensorReader::readMMEHeader() { 112 char header[64]; 113 char object[64]; 114 char format[64]; 115 char field[64]; 116 char symmetry[64]; 117 // Read header line. 118 if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field, 119 symmetry) != 5) 120 MLIR_SPARSETENSOR_FATAL("Corrupt header in %s\n", filename); 121 // Convert all to lowercase up front (to avoid accidental redundancy). 122 toLower(header); 123 toLower(object); 124 toLower(format); 125 toLower(field); 126 toLower(symmetry); 127 // Process `field`, which specify pattern or the data type of the values. 128 if (streq(field, "pattern")) 129 valueKind_ = ValueKind::kPattern; 130 else if (streq(field, "real")) 131 valueKind_ = ValueKind::kReal; 132 else if (streq(field, "integer")) 133 valueKind_ = ValueKind::kInteger; 134 else if (streq(field, "complex")) 135 valueKind_ = ValueKind::kComplex; 136 else 137 MLIR_SPARSETENSOR_FATAL("Unexpected header field value in %s\n", filename); 138 // Set properties. 139 isSymmetric_ = streq(symmetry, "symmetric"); 140 // Make sure this is a general sparse matrix. 141 if (strne(header, "%%matrixmarket") || strne(object, "matrix") || 142 strne(format, "coordinate") || 143 (strne(symmetry, "general") && !isSymmetric_)) 144 MLIR_SPARSETENSOR_FATAL("Cannot find a general sparse matrix in %s\n", 145 filename); 146 // Skip comments. 147 while (true) { 148 readLine(); 149 if (line[0] != '%') 150 break; 151 } 152 // Next line contains M N NNZ. 153 idata[0] = 2; // rank 154 if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3, 155 idata + 1) != 3) 156 MLIR_SPARSETENSOR_FATAL("Cannot find size in %s\n", filename); 157 } 158 159 /// Read the "extended" FROSTT header. Although not part of the documented 160 /// format, we assume that the file starts with optional comments followed 161 /// by two lines that define the rank, the number of nonzeros, and the 162 /// dimensions sizes (one per rank) of the sparse tensor. 163 void SparseTensorReader::readExtFROSTTHeader() { 164 // Skip comments. 165 while (true) { 166 readLine(); 167 if (line[0] != '#') 168 break; 169 } 170 // Next line contains RANK and NNZ. 171 if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) 172 MLIR_SPARSETENSOR_FATAL("Cannot find metadata in %s\n", filename); 173 // Followed by a line with the dimension sizes (one per rank). 174 for (uint64_t r = 0; r < idata[0]; ++r) 175 if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) 176 MLIR_SPARSETENSOR_FATAL("Cannot find dimension size %s\n", filename); 177 readLine(); // end of line 178 // The FROSTT format does not define the data type of the nonzero elements. 179 valueKind_ = ValueKind::kUndefined; 180 } 181