xref: /llvm-project/mlir/include/mlir/IR/DialectImplementation.h (revision d0b7633d7ad566579bfb794f95cce9aef294c92b)
1 //===- DialectImplementation.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains utility classes for implementing dialect attributes and
// types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_DIALECTIMPLEMENTATION_H
15 #define MLIR_IR_DIALECTIMPLEMENTATION_H
16 
17 #include "mlir/IR/OpImplementation.h"
18 #include <type_traits>
19 
20 namespace {
21 
22 // reference https://stackoverflow.com/a/16000226
23 template <typename T, typename = void>
24 struct HasStaticDialectName : std::false_type {};
25 
26 template <typename T>
27 struct HasStaticDialectName<
28     T, typename std::enable_if<
29            std::is_same<::llvm::StringLiteral,
30                         std::decay_t<decltype(T::dialectName)>>::value,
31            void>::type> : std::true_type {};
32 
33 } // namespace
34 
35 namespace mlir {
36 
37 //===----------------------------------------------------------------------===//
38 // DialectAsmPrinter
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a pure-virtual base class that exposes the asmprinter hooks
42 /// necessary to implement a custom printAttribute/printType() method on a
43 /// dialect.
44 class DialectAsmPrinter : public AsmPrinter {
45 public:
46   using AsmPrinter::AsmPrinter;
47   ~DialectAsmPrinter() override;
48 };
49 
50 //===----------------------------------------------------------------------===//
51 // DialectAsmParser
52 //===----------------------------------------------------------------------===//
53 
54 /// The DialectAsmParser has methods for interacting with the asm parser when
55 /// parsing attributes and types.
56 class DialectAsmParser : public AsmParser {
57 public:
58   using AsmParser::AsmParser;
59   ~DialectAsmParser() override;
60 
61   /// Returns the full specification of the symbol being parsed. This allows for
62   /// using a separate parser if necessary.
63   virtual StringRef getFullSymbolSpec() const = 0;
64 };
65 
66 //===----------------------------------------------------------------------===//
67 // Parse Fields
68 //===----------------------------------------------------------------------===//
69 
/// Customization point for parsing a single parameter of an attribute or type.
/// Auto-generated parsers call `FieldParser<T>::parse(parser)`, where `T` is
/// the parameter's storage type. Specialize this template to teach those
/// parsers about a custom type. The defaulted second parameter lets partial
/// specializations constrain `T` with `std::enable_if_t<..., T>`.
template <typename FieldT, typename = FieldT>
struct FieldParser;
75 
76 /// Parse an attribute.
77 template <typename AttributeT>
78 struct FieldParser<
79     AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
80                                  AttributeT>> {
81   static FailureOr<AttributeT> parse(AsmParser &parser) {
82     if constexpr (HasStaticDialectName<AttributeT>::value) {
83       parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
84     }
85     AttributeT value;
86     if (parser.parseCustomAttributeWithFallback(value))
87       return failure();
88     return value;
89   }
90 };
91 
92 /// Parse a type.
93 template <typename TypeT>
94 struct FieldParser<
95     TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
96   static FailureOr<TypeT> parse(AsmParser &parser) {
97     TypeT value;
98     if (parser.parseCustomTypeWithFallback(value))
99       return failure();
100     return value;
101   }
102 };
103 
104 /// Parse any integer.
105 template <typename IntT>
106 struct FieldParser<IntT,
107                    std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
108   static FailureOr<IntT> parse(AsmParser &parser) {
109     IntT value = 0;
110     if (parser.parseInteger(value))
111       return failure();
112     return value;
113   }
114 };
115 
116 /// Parse a string.
117 template <>
118 struct FieldParser<std::string> {
119   static FailureOr<std::string> parse(AsmParser &parser) {
120     std::string value;
121     if (parser.parseString(&value))
122       return failure();
123     return value;
124   }
125 };
126 
127 /// Parse an Optional attribute.
128 template <typename AttributeT>
129 struct FieldParser<
130     std::optional<AttributeT>,
131     std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
132                      std::optional<AttributeT>>> {
133   static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
134     if constexpr (HasStaticDialectName<AttributeT>::value) {
135       parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
136     }
137     AttributeT attr;
138     OptionalParseResult result = parser.parseOptionalAttribute(attr);
139     if (result.has_value()) {
140       if (succeeded(*result))
141         return {std::optional<AttributeT>(attr)};
142       return failure();
143     }
144     return {std::nullopt};
145   }
146 };
147 
148 /// Parse an Optional integer.
149 template <typename IntT>
150 struct FieldParser<
151     std::optional<IntT>,
152     std::enable_if_t<std::is_integral<IntT>::value, std::optional<IntT>>> {
153   static FailureOr<std::optional<IntT>> parse(AsmParser &parser) {
154     IntT value;
155     OptionalParseResult result = parser.parseOptionalInteger(value);
156     if (result.has_value()) {
157       if (succeeded(*result))
158         return {std::optional<IntT>(value)};
159       return failure();
160     }
161     return {std::nullopt};
162   }
163 };
164 
namespace detail {
/// Detection alias: the result type of calling `push_back` on a `T` rvalue
/// with an rvalue of `T::value_type`. It is ill-formed when no such call
/// exists, which makes it usable with `llvm::is_detected`.
template <typename T>
using has_push_back_t = decltype(std::declval<T>().push_back(
    std::declval<typename T::value_type>()));
} // namespace detail
170 
171 /// Parse any container that supports back insertion as a list.
172 template <typename ContainerT>
173 struct FieldParser<ContainerT,
174                    std::enable_if_t<llvm::is_detected<detail::has_push_back_t,
175                                                       ContainerT>::value,
176                                     ContainerT>> {
177   using ElementT = typename ContainerT::value_type;
178   static FailureOr<ContainerT> parse(AsmParser &parser) {
179     ContainerT elements;
180     auto elementParser = [&]() {
181       auto element = FieldParser<ElementT>::parse(parser);
182       if (failed(element))
183         return failure();
184       elements.push_back(std::move(*element));
185       return success();
186     };
187     if (parser.parseCommaSeparatedList(elementParser))
188       return failure();
189     return elements;
190   }
191 };
192 
193 /// Parse an affine map.
194 template <>
195 struct FieldParser<AffineMap> {
196   static FailureOr<AffineMap> parse(AsmParser &parser) {
197     AffineMap map;
198     if (failed(parser.parseAffineMap(map)))
199       return failure();
200     return map;
201   }
202 };
203 
204 } // namespace mlir
205 
206 #endif // MLIR_IR_DIALECTIMPLEMENTATION_H
207