1 //===- DialectImplementation.h ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains utilities classes for implementing dialect attributes and 10 // types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_IR_DIALECTIMPLEMENTATION_H 15 #define MLIR_IR_DIALECTIMPLEMENTATION_H 16 17 #include "mlir/IR/OpImplementation.h" 18 #include <type_traits> 19 20 namespace { 21 22 // reference https://stackoverflow.com/a/16000226 23 template <typename T, typename = void> 24 struct HasStaticDialectName : std::false_type {}; 25 26 template <typename T> 27 struct HasStaticDialectName< 28 T, typename std::enable_if< 29 std::is_same<::llvm::StringLiteral, 30 std::decay_t<decltype(T::dialectName)>>::value, 31 void>::type> : std::true_type {}; 32 33 } // namespace 34 35 namespace mlir { 36 37 //===----------------------------------------------------------------------===// 38 // DialectAsmPrinter 39 //===----------------------------------------------------------------------===// 40 41 /// This is a pure-virtual base class that exposes the asmprinter hooks 42 /// necessary to implement a custom printAttribute/printType() method on a 43 /// dialect. 44 class DialectAsmPrinter : public AsmPrinter { 45 public: 46 using AsmPrinter::AsmPrinter; 47 ~DialectAsmPrinter() override; 48 }; 49 50 //===----------------------------------------------------------------------===// 51 // DialectAsmParser 52 //===----------------------------------------------------------------------===// 53 54 /// The DialectAsmParser has methods for interacting with the asm parser when 55 /// parsing attributes and types. 56 class DialectAsmParser : public AsmParser { 57 public: 58 using AsmParser::AsmParser; 59 ~DialectAsmParser() override; 60 61 /// Returns the full specification of the symbol being parsed. This allows for 62 /// using a separate parser if necessary. 63 virtual StringRef getFullSymbolSpec() const = 0; 64 }; 65 66 //===----------------------------------------------------------------------===// 67 // Parse Fields 68 //===----------------------------------------------------------------------===// 69 70 /// Provide a template class that can be specialized by users to dispatch to 71 /// parsers. Auto-generated parsers generate calls to `FieldParser<T>::parse`, 72 /// where `T` is the parameter storage type, to parse custom types. 73 template <typename T, typename = T> 74 struct FieldParser; 75 76 /// Parse an attribute. 77 template <typename AttributeT> 78 struct FieldParser< 79 AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value, 80 AttributeT>> { 81 static FailureOr<AttributeT> parse(AsmParser &parser) { 82 if constexpr (HasStaticDialectName<AttributeT>::value) { 83 parser.getContext()->getOrLoadDialect(AttributeT::dialectName); 84 } 85 AttributeT value; 86 if (parser.parseCustomAttributeWithFallback(value)) 87 return failure(); 88 return value; 89 } 90 }; 91 92 /// Parse a type. 93 template <typename TypeT> 94 struct FieldParser< 95 TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> { 96 static FailureOr<TypeT> parse(AsmParser &parser) { 97 TypeT value; 98 if (parser.parseCustomTypeWithFallback(value)) 99 return failure(); 100 return value; 101 } 102 }; 103 104 /// Parse any integer. 105 template <typename IntT> 106 struct FieldParser<IntT, 107 std::enable_if_t<std::is_integral<IntT>::value, IntT>> { 108 static FailureOr<IntT> parse(AsmParser &parser) { 109 IntT value = 0; 110 if (parser.parseInteger(value)) 111 return failure(); 112 return value; 113 } 114 }; 115 116 /// Parse a string. 117 template <> 118 struct FieldParser<std::string> { 119 static FailureOr<std::string> parse(AsmParser &parser) { 120 std::string value; 121 if (parser.parseString(&value)) 122 return failure(); 123 return value; 124 } 125 }; 126 127 /// Parse an Optional attribute. 128 template <typename AttributeT> 129 struct FieldParser< 130 std::optional<AttributeT>, 131 std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value, 132 std::optional<AttributeT>>> { 133 static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) { 134 if constexpr (HasStaticDialectName<AttributeT>::value) { 135 parser.getContext()->getOrLoadDialect(AttributeT::dialectName); 136 } 137 AttributeT attr; 138 OptionalParseResult result = parser.parseOptionalAttribute(attr); 139 if (result.has_value()) { 140 if (succeeded(*result)) 141 return {std::optional<AttributeT>(attr)}; 142 return failure(); 143 } 144 return {std::nullopt}; 145 } 146 }; 147 148 /// Parse an Optional integer. 149 template <typename IntT> 150 struct FieldParser< 151 std::optional<IntT>, 152 std::enable_if_t<std::is_integral<IntT>::value, std::optional<IntT>>> { 153 static FailureOr<std::optional<IntT>> parse(AsmParser &parser) { 154 IntT value; 155 OptionalParseResult result = parser.parseOptionalInteger(value); 156 if (result.has_value()) { 157 if (succeeded(*result)) 158 return {std::optional<IntT>(value)}; 159 return failure(); 160 } 161 return {std::nullopt}; 162 } 163 }; 164 165 namespace detail { 166 template <typename T> 167 using has_push_back_t = decltype(std::declval<T>().push_back( 168 std::declval<typename T::value_type &&>())); 169 } // namespace detail 170 171 /// Parse any container that supports back insertion as a list. 172 template <typename ContainerT> 173 struct FieldParser<ContainerT, 174 std::enable_if_t<llvm::is_detected<detail::has_push_back_t, 175 ContainerT>::value, 176 ContainerT>> { 177 using ElementT = typename ContainerT::value_type; 178 static FailureOr<ContainerT> parse(AsmParser &parser) { 179 ContainerT elements; 180 auto elementParser = [&]() { 181 auto element = FieldParser<ElementT>::parse(parser); 182 if (failed(element)) 183 return failure(); 184 elements.push_back(std::move(*element)); 185 return success(); 186 }; 187 if (parser.parseCommaSeparatedList(elementParser)) 188 return failure(); 189 return elements; 190 } 191 }; 192 193 /// Parse an affine map. 194 template <> 195 struct FieldParser<AffineMap> { 196 static FailureOr<AffineMap> parse(AsmParser &parser) { 197 AffineMap map; 198 if (failed(parser.parseAffineMap(map))) 199 return failure(); 200 return map; 201 } 202 }; 203 204 } // namespace mlir 205 206 #endif // MLIR_IR_DIALECTIMPLEMENTATION_H 207