//===- LvlTypeParser.cpp - `LevelType` parser --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "LvlTypeParser.h"
10 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
11
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 using namespace mlir::sparse_tensor::ir_detail;
15
16 //===----------------------------------------------------------------------===//
// Returns `failure()` from the enclosing function when `failed(STMT)` holds.
// Expands to a complete `if`/`else` statement, so call sites need no trailing
// semicolon. The early return sits in the `else` arm on purpose: when the
// macro is used as the unbraced body of an outer `if ... else`, the outer
// `else` can no longer be captured by the `if` hidden inside the macro.
#define FAILURE_IF_FAILED(STMT)                                                \
  if (!failed(STMT)) {                                                         \
  } else {                                                                     \
    return failure();                                                          \
  }
21
// Emits `MSG` as an error at `loc` and returns the resulting diagnostic from
// the enclosing function when `COND` holds.
// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
// Expands to a complete `if`/`else` statement, so call sites need no trailing
// semicolon. The early return sits in the `else` arm on purpose: when the
// macro is used as the unbraced body of an outer `if ... else`, the outer
// `else` can no longer be captured by the `if` hidden inside the macro.
#define ERROR_IF(COND, MSG)                                                    \
  if (!(COND)) {                                                               \
  } else {                                                                     \
    return parser.emitError(loc, MSG);                                         \
  }
27
28 //===----------------------------------------------------------------------===//
29 // `LvlTypeParser` implementation.
30 //===----------------------------------------------------------------------===//
31
parseLvlType(AsmParser & parser) const32 FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
33 StringRef base;
34 const auto loc = parser.getCurrentLocation();
35 ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
36 "expected valid level format (e.g. dense, compressed or singleton)")
37 uint64_t properties = 0;
38 SmallVector<unsigned> structured;
39
40 if (base == "structured") {
41 ParseResult res = parser.parseCommaSeparatedList(
42 mlir::OpAsmParser::Delimiter::OptionalSquare,
43 [&]() -> ParseResult { return parseStructured(parser, &structured); },
44 " in structured n out of m");
45 FAILURE_IF_FAILED(res)
46 if (structured.size() != 2) {
47 parser.emitError(loc, "expected exactly 2 structured sizes");
48 return failure();
49 }
50 if (structured[0] > structured[1]) {
51 parser.emitError(loc, "expected n <= m in n_out_of_m");
52 return failure();
53 }
54 }
55
56 ParseResult res = parser.parseCommaSeparatedList(
57 mlir::OpAsmParser::Delimiter::OptionalParen,
58 [&]() -> ParseResult { return parseProperty(parser, &properties); },
59 " in level property list");
60 FAILURE_IF_FAILED(res)
61
62 // Set the base bit for properties.
63 if (base == "dense") {
64 properties |= static_cast<uint64_t>(LevelFormat::Dense);
65 } else if (base == "batch") {
66 properties |= static_cast<uint64_t>(LevelFormat::Batch);
67 } else if (base == "compressed") {
68 properties |= static_cast<uint64_t>(LevelFormat::Compressed);
69 } else if (base == "structured") {
70 properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
71 properties |= nToBits(structured[0]) | mToBits(structured[1]);
72 } else if (base == "loose_compressed") {
73 properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
74 } else if (base == "singleton") {
75 properties |= static_cast<uint64_t>(LevelFormat::Singleton);
76 } else {
77 parser.emitError(loc, "unknown level format: ") << base;
78 return failure();
79 }
80
81 ERROR_IF(!isValidLT(static_cast<LevelType>(properties)),
82 "invalid level type: level format doesn't support the properties");
83 return properties;
84 }
85
parseProperty(AsmParser & parser,uint64_t * properties) const86 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
87 uint64_t *properties) const {
88 StringRef strVal;
89 auto loc = parser.getCurrentLocation();
90 ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
91 "expected valid level property (e.g. nonordered, nonunique or high)")
92 if (strVal == toPropString(LevelPropNonDefault::Nonunique)) {
93 *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
94 } else if (strVal == toPropString(LevelPropNonDefault::Nonordered)) {
95 *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
96 } else if (strVal == toPropString(LevelPropNonDefault::SoA)) {
97 *properties |= static_cast<uint64_t>(LevelPropNonDefault::SoA);
98 } else {
99 parser.emitError(loc, "unknown level property: ") << strVal;
100 return failure();
101 }
102 return success();
103 }
104
105 ParseResult
parseStructured(AsmParser & parser,SmallVector<unsigned> * structured) const106 LvlTypeParser::parseStructured(AsmParser &parser,
107 SmallVector<unsigned> *structured) const {
108 int intVal;
109 auto loc = parser.getCurrentLocation();
110 OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
111 if (intValParseResult.has_value()) {
112 if (failed(*intValParseResult)) {
113 parser.emitError(loc, "failed to parse structured size");
114 return failure();
115 }
116 if (intVal < 0) {
117 parser.emitError(loc, "expected structured size to be >= 0");
118 return failure();
119 }
120 structured->push_back(intVal);
121 return success();
122 }
123 parser.emitError(loc, "expected valid integer for structured size");
124 return failure();
125 }
126
127 //===----------------------------------------------------------------------===//
128