xref: /llvm-project/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp (revision c1fa60b4cde512964544ab66404dea79dbc5dcb4)
1 //===- PDLTypes.cpp - Pattern Descriptor Language Types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl;
17 
18 //===----------------------------------------------------------------------===//
19 // TableGen'd type method definitions
20 //===----------------------------------------------------------------------===//
21 
22 #define GET_TYPEDEF_CLASSES
23 #include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
24 
25 //===----------------------------------------------------------------------===//
26 // PDLDialect
27 //===----------------------------------------------------------------------===//
28 
/// Registers the PDL types with this dialect by forwarding a type list to
/// `addTypes<...>`.
void PDLDialect::registerTypes() {
  addTypes<
// The template arguments come from the generated include below; defining
// GET_TYPEDEF_LIST presumably selects its comma-separated type-list section
// (NOTE(review): confirm against the generated PDLOpsTypes.cpp.inc).
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
      >();
}
35 
parsePDLType(AsmParser & parser)36 static Type parsePDLType(AsmParser &parser) {
37   StringRef typeTag;
38   {
39     Type genType;
40     auto parseResult = generatedTypeParser(parser, &typeTag, genType);
41     if (parseResult.has_value())
42       return genType;
43   }
44 
45   // FIXME: This ends up with a double error being emitted if `RangeType` also
46   // emits an error. We should rework the `generatedTypeParser` to better
47   // support when the keyword is valid but the individual type parser itself
48   // emits an error.
49   parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
50       << typeTag << "'";
51   return Type();
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDL Types
56 //===----------------------------------------------------------------------===//
57 
classof(Type type)58 bool PDLType::classof(Type type) {
59   return llvm::isa<PDLDialect>(type.getDialect());
60 }
61 
getRangeElementTypeOrSelf(Type type)62 Type pdl::getRangeElementTypeOrSelf(Type type) {
63   if (auto rangeType = llvm::dyn_cast<RangeType>(type))
64     return rangeType.getElementType();
65   return type;
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // RangeType
70 //===----------------------------------------------------------------------===//
71 
parse(AsmParser & parser)72 Type RangeType::parse(AsmParser &parser) {
73   if (parser.parseLess())
74     return Type();
75 
76   SMLoc elementLoc = parser.getCurrentLocation();
77   Type elementType = parsePDLType(parser);
78   if (!elementType || parser.parseGreater())
79     return Type();
80 
81   if (llvm::isa<RangeType>(elementType)) {
82     parser.emitError(elementLoc)
83         << "element of pdl.range cannot be another range, but got"
84         << elementType;
85     return Type();
86   }
87   return RangeType::get(elementType);
88 }
89 
print(AsmPrinter & printer) const90 void RangeType::print(AsmPrinter &printer) const {
91   printer << "<";
92   (void)generatedTypePrinter(getElementType(), printer);
93   printer << ">";
94 }
95 
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)96 LogicalResult RangeType::verify(function_ref<InFlightDiagnostic()> emitError,
97                                 Type elementType) {
98   if (!llvm::isa<PDLType>(elementType) || llvm::isa<RangeType>(elementType)) {
99     return emitError()
100            << "expected element of pdl.range to be one of [!pdl.attribute, "
101               "!pdl.operation, !pdl.type, !pdl.value], but got "
102            << elementType;
103   }
104   return success();
105 }
106