xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (revision 2f52bbeb6f6f3b7abef19cb5297773d95aa0b434)
//===- LvlTypeParser.cpp - `LevelType` parser -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "LvlTypeParser.h"
10 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 using namespace mlir::sparse_tensor::ir_detail;
15 
16 //===----------------------------------------------------------------------===//
17 #define FAILURE_IF_FAILED(STMT)                                                \
18   if (failed(STMT)) {                                                          \
19     return failure();                                                          \
20   }
21 
22 // NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
23 #define ERROR_IF(COND, MSG)                                                    \
24   if (COND) {                                                                  \
25     return parser.emitError(loc, MSG);                                         \
26   }
27 
28 //===----------------------------------------------------------------------===//
29 // `LvlTypeParser` implementation.
30 //===----------------------------------------------------------------------===//
31 
parseLvlType(AsmParser & parser) const32 FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
33   StringRef base;
34   const auto loc = parser.getCurrentLocation();
35   ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
36            "expected valid level format (e.g. dense, compressed or singleton)")
37   uint64_t properties = 0;
38   SmallVector<unsigned> structured;
39 
40   if (base == "structured") {
41     ParseResult res = parser.parseCommaSeparatedList(
42         mlir::OpAsmParser::Delimiter::OptionalSquare,
43         [&]() -> ParseResult { return parseStructured(parser, &structured); },
44         " in structured n out of m");
45     FAILURE_IF_FAILED(res)
46     if (structured.size() != 2) {
47       parser.emitError(loc, "expected exactly 2 structured sizes");
48       return failure();
49     }
50     if (structured[0] > structured[1]) {
51       parser.emitError(loc, "expected n <= m in n_out_of_m");
52       return failure();
53     }
54   }
55 
56   ParseResult res = parser.parseCommaSeparatedList(
57       mlir::OpAsmParser::Delimiter::OptionalParen,
58       [&]() -> ParseResult { return parseProperty(parser, &properties); },
59       " in level property list");
60   FAILURE_IF_FAILED(res)
61 
62   // Set the base bit for properties.
63   if (base == "dense") {
64     properties |= static_cast<uint64_t>(LevelFormat::Dense);
65   } else if (base == "batch") {
66     properties |= static_cast<uint64_t>(LevelFormat::Batch);
67   } else if (base == "compressed") {
68     properties |= static_cast<uint64_t>(LevelFormat::Compressed);
69   } else if (base == "structured") {
70     properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
71     properties |= nToBits(structured[0]) | mToBits(structured[1]);
72   } else if (base == "loose_compressed") {
73     properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
74   } else if (base == "singleton") {
75     properties |= static_cast<uint64_t>(LevelFormat::Singleton);
76   } else {
77     parser.emitError(loc, "unknown level format: ") << base;
78     return failure();
79   }
80 
81   ERROR_IF(!isValidLT(static_cast<LevelType>(properties)),
82            "invalid level type: level format doesn't support the properties");
83   return properties;
84 }
85 
parseProperty(AsmParser & parser,uint64_t * properties) const86 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
87                                          uint64_t *properties) const {
88   StringRef strVal;
89   auto loc = parser.getCurrentLocation();
90   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
91            "expected valid level property (e.g. nonordered, nonunique or high)")
92   if (strVal == toPropString(LevelPropNonDefault::Nonunique)) {
93     *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
94   } else if (strVal == toPropString(LevelPropNonDefault::Nonordered)) {
95     *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
96   } else if (strVal == toPropString(LevelPropNonDefault::SoA)) {
97     *properties |= static_cast<uint64_t>(LevelPropNonDefault::SoA);
98   } else {
99     parser.emitError(loc, "unknown level property: ") << strVal;
100     return failure();
101   }
102   return success();
103 }
104 
105 ParseResult
parseStructured(AsmParser & parser,SmallVector<unsigned> * structured) const106 LvlTypeParser::parseStructured(AsmParser &parser,
107                                SmallVector<unsigned> *structured) const {
108   int intVal;
109   auto loc = parser.getCurrentLocation();
110   OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
111   if (intValParseResult.has_value()) {
112     if (failed(*intValParseResult)) {
113       parser.emitError(loc, "failed to parse structured size");
114       return failure();
115     }
116     if (intVal < 0) {
117       parser.emitError(loc, "expected structured size to be >= 0");
118       return failure();
119     }
120     structured->push_back(intVal);
121     return success();
122   }
123   parser.emitError(loc, "expected valid integer for structured size");
124   return failure();
125 }
126 
127 //===----------------------------------------------------------------------===//
128